本文的目的主要在于改进自注意力计算的高昂计算成本。所以基于局部自注意力的形式进行了扩展,实现了一种更加高效的全局注意力形式,而免去了Swin那样的划窗操作(划窗操作需要进行padding和mask,以及划窗仅仅会覆盖不同局部区域的部分内容)或者其他更为复杂的例如token unfolding和rolling操作,甚至是对于key和value的额外计算。
仍然基于windows-based attention的形式,从而保证了相对于图像大小线性的放缩关系。
在类似于Swin中的local attention的基础上,作者们构建了一种新的global attention形式,来实现横跨不同local window之间的图像patch上的信息交流。
global attention的核心是对原本local attention的query的改进。其直接使用从原始图像特征上利用CNN结构(即文中的Global Query Generator)提取缩小到窗口区域对应尺寸嵌入,使其与image token窗口中的local key和local value进行计算,从而允许捕获跨区域交互的长距离信息。
因此local attention与global attention的唯一差异在于query的来源,前者来自模块输入,而后者则来自于stage初生成,内部各个module共享的global token。
整个stage的流程如下:
GCViTLayer
对应的代码段如下:
# https://github.com/NVlabs/GCVit/blob/caa62fc4d55cf822cf3bef5eb8b69cc11b90e885/models/gc_vit.py#L525-L589
def forward(self, x):
q_global = self.q_global_gen(_to_channel_first(x))
for blk in self.blocks:
x = blk(x, q_global)
if self.downsample is None:
return x
return self.downsample(x)