京东AI研究院提出的一种新的注意力结构。将CoT Block代替了ResNet结构中的3x3卷积,在分类检测分割等任务效果都出类拔萃
论文:Contextual Transformer Networks for Visual Recognition
论文地址:https://arxiv.org/abs/2107.12292
有自注意力的Transformer引发了自然语言处理领域的革命,最近还激发了Transformer式架构设计的出现,并在众多计算机视觉任务中取得了具有竞争力的结果。
大多数现有设计直接在2D特征图上使用自注意力来获得基于每个空间位置的独立查询和键对的注意力矩阵,但未充分利用相邻键之间的丰富上下文。在今天分享的工作中,研究者设计了一个新颖的Transformer风格的模块,即Contextual Transformer (CoT)块,用于视觉识别。这种设计充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而增强视觉表示能力。从技术上讲,CoT块首先通过3×3卷积对输入键进行上下文编码,从而产生输入的静态上下文表示。
上图a是传统的self-attention仅利用孤立的查询-键对来测量注意力矩阵,但未充分利用键之间的丰富上下文。 b就是CoT块
研究者进一步将编码的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态多头注意力矩阵。学习到的注意力矩阵乘以输入值以实现输入的动态上下文表示。静态和动态上下文表示的融合最终作为输出。CoT块很吸引人,因为它可以轻松替换ResNet架构中的每个3 × 3卷积,产生一个名为Contextual Transformer Networks (CoTNet)的Transformer式主干。通过对广泛应用(例如图像识别、对象检测和实例分割)的大量实验,验证了CoTNet作为更强大的主干的优越性
Attention注意力机制与self-attention自注意力机制
为什么要注意力机制?
在Attention诞生之前,已经有CNN和RNN及其变体模型了,那为什么还要引入attention机制?主要有两个方面的原因,如下:
(1)计算能力的限制:当要记住很多“信息“,模型就要变得更复杂,然而目前计算能力依然是限制神经网络发展的瓶颈。
(2)优化算法的限制:LSTM只能在一定程度上缓解RNN中的长距离依赖问题,且信息“记忆”能力并不高。
什么是注意力机制
在介绍什么是注意力机制之前,先让大家看一张图片。当大家看到下面图片,会首先看到什么内容?当过载信息映入眼帘时,我们的大脑会把注意力放在主要的信息上,这就是大脑的注意力机制。
同样,当我们读一句话时,大脑也会首先记住重要的词汇,这样就可以把注意力机制应用到自然语言处理任务中,于是人们就通过借助人脑处理信息过载的方式,提出了Attention机制。
self attention是注意力机制中的一种,也是transformer中的重要组成部分。自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性。自注意力机制在文本中的应用,主要是通过计算单词间的互相影响,来解决长距离依赖问题。
传统的自注意力很好地触发了不同空间位置的特征交互,具体取决于输入本身。然而,在传统的自注意力机制中,所有成对的查询键关系都是通过孤立的查询键对独立学习的,而无需探索其间的丰富上下文。这严重限制了自注意力学习在2D特征图上进行视觉表示学习的能力。

为了缓解这个问题,研究者构建了一个新的Transformer风格的构建块,即上图 (b)中的 Contextual Transformer (CoT) 块,它将上下文信息挖掘和自注意力学习集成到一个统一的架构中。

- # YOLOv7 🚀, GPL-3.0 license
- # parameters
- nc: 80 # number of classes
- depth_multiple: 0.33 # model depth multiple
- width_multiple: 1.0 # layer channel multiple
-
- # anchors
- anchors:
- - [12,16, 19,36, 40,28] # P3/8
- - [36,75, 76,55, 72,146] # P4/16
- - [142,110, 192,243, 459,401] # P5/32
-
- # yolov7 backbone by yoloair
- backbone:
- # [from, number, module, args]
- [[-1, 1, Conv, [32, 3, 1]], # 0
- [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
- [-1, 1, Conv, [64, 3, 1]],
- [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
- [-1, 1, C3HB, [128]],
- [-1, 1, Conv, [256, 3, 2]],
- [-1, 1, MP, []],
- [-1, 1, Conv, [128, 1, 1]],
- [-3, 1, Conv, [128, 1, 1]],
- [-1, 1, Conv, [128, 3, 2]],
- [[-1, -3], 1, Concat, [1]], # 16-P3/8
- [-1, 1, Conv, [128, 1, 1]],
- [-2, 1, Conv, [128, 1, 1]],
- [-1, 1, Conv, [128, 3, 1]],
- [-1, 1, Conv, [128, 3, 1]],
- [-1, 1, Conv, [128, 3, 1]],
- [-1, 1, Conv, [128, 3, 1]],
- [[-1, -3, -5, -6], 1, Concat, [1]],
- [-1, 1, Conv, [512, 1, 1]],
- [-1, 1, MP, []],
- [-1, 1, Conv, [256, 1, 1]],
- [-3, 1, Conv, [256, 1, 1]],
- [-1, 1, Conv, [256, 3, 2]],
- [[-1, -3], 1, Concat, [1]],
- [-1, 1, Conv, [256, 1, 1]],
- [-2, 1, Conv, [256, 1, 1]],
- [-1, 1, Conv, [256, 3, 1]],
- [-1, 1, Conv, [256, 3, 1]],
- [-1, 1, Conv, [256, 3, 1]],
- [-1, 1, Conv, [256, 3, 1]],
- [[-1, -3, -5, -6], 1, Concat, [1]],
- [-1, 1, Conv, [1024, 1, 1]],
- [-1, 1, MP, []],
- [-1, 1, Conv, [512, 1, 1]],
- [-3, 1, Conv, [512, 1, 1]],
- [-1, 1, Conv, [512, 3, 2]],
- [[-1, -3], 1, Concat, [1]],
- [-1, 1, C3HB, [1024]],
- [-1, 1, Conv, [256, 3, 1]],
- ]
-
- # yolov7 head by yoloair
- head:
- [[-1, 1, SPPCSPC, [512]],
- [-1, 1, Conv, [256, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [31, 1, Conv, [256, 1, 1]],
- [[-1, -2], 1, Concat, [1]],
- [-1, 1, CoT3, [128]],
- [-1, 1, Conv, [128, 1, 1]],
- [-1, 1, nn.Upsample, [None, 2, 'nearest']],
- [18, 1, Conv, [128, 1, 1]],
- [[-1, -2], 1, Concat, [1]],
- [-1, 1, CoT3, [128]],
- [-1, 1, MP, []],
- [-1, 1, Conv, [128, 1, 1]],
- [-3, 1, Conv, [128, 1, 1]],
- [-1, 1, Conv, [128, 3, 2]],
- [[-1, -3, 44], 1, Concat, [1]],
- [-1, 1, CoT3, [256]],
- [-1, 1, MP, []],
- [-1, 1, Conv, [256, 1, 1]],
- [-3, 1, Conv, [256, 1, 1]],
- [-1, 1, Conv, [256, 3, 2]],
- [[-1, -3, 39], 1, Concat, [1]],
- [-1, 3, CoT3, [512]],
-
- # 检测头 -----------------------------
- [49, 1, RepConv, [256, 3, 1]],
- [55, 1, RepConv, [512, 3, 1]],
- [61, 1, RepConv, [1024, 3, 1]],
-
- [[62,63,64], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
- ]
-
./models/common.py文件增加以下模块
- class CoT3(nn.Module):
- def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
- super().__init__()
- c_ = int(c2 * e) # hidden channels
- self.cv1 = Conv(c1, c_, 1, 1)
- self.cv2 = Conv(c1, c_, 1, 1)
- self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
- self.m = nn.Sequential(*[CoTBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
- # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
-
- def forward(self, x):
- return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
-
- class CoTBottleneck(nn.Module):
- def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
- super(CoTBottleneck, self).__init__()
- c_ = int(c2 * e) # hidden channels
- self.cv1 = Conv(c1, c_, 1, 1)
- self.cv2 = CoT(c_, 3)
- self.add = shortcut and c1 == c2
-
- def forward(self, x):
- return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
-
- class CoT(nn.Module):
- # Contextual Transformer Networks https://arxiv.org/abs/2107.12292
- def __init__(self, dim=512,kernel_size=3):
- super().__init__()
- self.dim=dim
- self.kernel_size=kernel_size
-
- self.key_embed=nn.Sequential(
- nn.Conv2d(dim,dim,kernel_size=kernel_size,padding=kernel_size//2,groups=4,bias=False),
- nn.BatchNorm2d(dim),
- nn.ReLU()
- )
- self.value_embed=nn.Sequential(
- nn.Conv2d(dim,dim,1,bias=False),
- nn.BatchNorm2d(dim)
- )
-
- factor=4
- self.attention_embed=nn.Sequential(
- nn.Conv2d(2*dim,2*dim//factor,1,bias=False),
- nn.BatchNorm2d(2*dim//factor),
- nn.ReLU(),
- nn.Conv2d(2*dim//factor,kernel_size*kernel_size*dim,1)
- )
-
-
- def forward(self, x):
- bs,c,h,w=x.shape
- k1=self.key_embed(x) #bs,c,h,w
- v=self.value_embed(x).view(bs,c,-1) #bs,c,h,w
-
- y=torch.cat([k1,x],dim=1) #bs,2c,h,w
- att=self.attention_embed(y) #bs,c*k*k,h,w
- att=att.reshape(bs,c,self.kernel_size*self.kernel_size,h,w)
- att=att.mean(2,keepdim=False).view(bs,c,-1) #bs,c,h*w
- k2=F.softmax(att,dim=-1)*v
- k2=k2.view(bs,c,h,w)
-
- return k1+k2
-
然后找到./models/yolo.py文件下里的parse_model函数,将加入的模块名CoT3加入进去
在 models/yolo.py文件夹下
定位到parse_model函数中
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
对应位置 下方只需要增加 CoT3模块
修改完成!