目录
2021年,软研亚洲究院在ICCV上发表 SwinTransformer文章,并获得ICCV 2021best paper等荣誉。 文中指出SwinTransformer可作为计算机视觉的通用backbone,在广泛的视觉任务如图像分类、目标检测、语义分割等领域都展现了巨大的的潜力,在多项视觉任务中霸榜。SwinTransformer的设计借鉴了ViT对图像数据的处理方法,通过分层设计结合多个等级的窗口划分降低了计算复杂度,提出位移窗口(Shifted windows)使相邻的窗口之间进行交互,从而达到全局建模的能力。下图所示为SwinTransformer与ViT模型的对比:
Swin Transformer使用了类似卷积神经网络中的层次化构建方法(Hierarchical feature maps),比如特征图尺寸中有对图像下采样4倍的,8倍的以及16倍的,这样的设计有助于网络提取更高级的特征,更适合构建目标检测,实例分割等任务。而在之前的Vision Transformer中是一开始就直接下采样16倍,后面的特征图也是维持这个下采样率不变。
层级式结构的好处在于不仅灵活的提供各种尺度的信息,同时还因为自注意力是在窗口内计算的,所以它的计算复杂度随着图片大小线性增长而不是平方级增长,这就使Swin Transformer能够在特别大的分辨率上进行预训练模型,并且通过多尺度的划分,使得Swin Transformer能够提取到多尺度的特征。相比之下,先前的ViT产生单一低分辨率的特征图,并且由于全局自注意力的计算,其计算复杂度是输入图像大小的二次方。
论文名称:《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》
论文地址:https://arxiv.org/abs/2103.14030
官方开源代码地址:https://github.com/microsoft/Swin-Transformer
Swin Transformer主要通过Windows Multi-Head Self-Attention(W-MSA)将特征图换分为多个不相交的窗口(Window),每个窗口里都有m*m个patch(原文中m=7,此时每个窗口有49个patch),自注意力计算都是分别在窗口内完成的,所以序列长度永远都是m*m(即49)。W-MSA很好的解决了内存和计算量的问题,但是窗口与窗口之间没有了信息交互,没能达到全局建模的效果,这就限制了模型的能力。因此,作者又提出了Shifted Windows Multi-Head Self-Attention(SW-MSA),通过窗口的滑动方法,使相邻的窗口之间进行交互,从而达到全局建模的能力。
上图为论文中给出的Swin Transformer(Swin-T)网络的架构图。
1)首先将图片输入到Patch Partition模块中进行分块,将输入的RGB图像分割成不重叠的patch。
原文中每4x4相邻的像素为一个Patch,然后在channel方向展平(flatten)。假设输入的是RGB三通道图片,那么每个patch就有4x4=16个像素,然后每个像素有R、G、B三个值所以展平后是4×4x3=48,所以通过Patch Partition后图像shape由 [H, W, 3]变成了 [H/4, W/4, 48]。然后在通过Linear Embeding层对每个像素的channel数据做线性变换,由48变成C,即图像shape再由 [H/4, W/4, 48]变成了 [H/4, W/4, C]。其实在源码中Patch Partition和Linear Embeding就是直接通过一个卷积层实现的,和之前Vision Transformer中讲的 Embedding层结构一模一样。
2)然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样。然后都是重复堆叠Swin Transformer Block注意这里的Block其实有两种结构,如图(b)中所示,这两种结构的不同之处仅在于一个使用了W-MSA结构,一个使用了SW-MSA结构。而且这两个结构是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。所以你会发现堆叠Swin Transformer Block的次数都是偶数(因为成对使用)。
3) 最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。(图中没有画)
在Swin Transformer结构中,每个Stage中首先要通过一个Patch Merging层进行下采样(Stage1除外)。如下图所示,假设输入Patch Merging的是一个4x4大小的单通道特征图(feature map),Patch Merging会将每个2x2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
Windows Multi-head Self-Attention(W-MSA)模块可以减少计算量。如下图所示,左侧使用的是普通的Multi-head Self-Attention(MSA)模块,对于feature map中的每个像素(或称作token,patch),在Self-Attention计算过程中需要和所有的像素去计算。但在图右侧,在使用Windows Multi-head Self-Attention(W-MSA)模块时,首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。
原文中,给出了两个计算复杂度的表达式:
假设feature map的h、w为64,M=4,C=96。
采用MSA模块的计算复杂度为 4x64x64x962 +2x(64x64)2x96 = 3372220416
采用W-MSA模块的计算复杂度为 4x64x64x962 +2x42x64x64x96 = 163577856
3372220416-163577856=3,208,642,560约节省了95%的的 FLOPs(计算复杂度)。
Swin Transformer中提出的W-MSA模块,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即进行偏移的W-MSA。如下图所示,左侧使用的是刚刚讲的W-MSA(假设是第L1层),那么根据之前介绍的W-MSA和SW-MSA是成对使用的,那么第L1+1层使用的就是SW-MSA(右侧图)。根据左右两幅图对比能够发现窗口(Windows)发生了偏移(可以理解成窗口从左上角分别向右侧和下方各偏移了⌊ ⌋个像素)。在偏移后,能够使多个窗口之间进行信息交流,如A窗口包含了1和2区域两个窗口的信息,C窗口包含了1、2、3、4四个窗口的信息,在偏移后的窗口上进行计算,就解决了不同窗口之间无法进行信息交流的问题。
但在偏移后的窗口中,从原始layer1变为layer1+1,窗口数目从4个变为了9个,且大小也不完全相同,在对每个窗口内部进行MSA计算时,导致计算难度增加。因此,作者在文中又提出了Efficient batch computation for shifted configuration,通过将窗口进行位移,变为四个大窗口进行计算,下图是原文中的示意图:
其主要的思想是将最上方的三个窗口移动至最下方,再将最左侧的三个窗口移动至最右侧,最后将小的窗口进行合并,形成四个4×4的大窗口,使计算量与初始窗口保持一致。如下图所示,先将123区域移动至最下方,再将471区域移动至最右方,最终将64、82、9731进行合并为三个4×4的整体区域,完成了窗口的缩减。
但合成后的区域中,区域的信息并不相连,此时在该窗口内进行MSA操做会导致信息的杂乱结合,因此作者在文中又提出了MsakMSA即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。其思想为:在Attention计算中,最后经过softmax操作时,当输入的值很小时,其输出几乎为0,Attention计算公式如下:
以上图为例,在6和4区域的合并窗口内进行MSA计算时,当像素是属于区域6的,那么我们可以将区域3中的所有像素做注意力计算后的结果都减去100,使其在通过softmax后,对应的权重几乎为0,达到该区域内的像素只与该区域进行匹配。在计算结束后,还需将数据移动会原来的位置上。
论文中指出,在使用了相对位置偏置后,模型在Imagenet数据集上的Top-1从80.1提升至81.3,提升的效果还是很明显的。在添加相对位置偏执后的Attention计算公式如下:
(其中B为relative position bias)
论文中并没有对相对位置偏置进行详细的详解,从源码中可以看出其主要思想如下:
首先我们可以构建出每个像素的绝对位置(左下方的矩阵),对于每个像素的绝对位置是使用行号和列号表示的。比如蓝色的像素对应的是第0行第0列所以绝对位置索引是( 0 , 0 ) ,接下来再看看相对位置索引。以蓝色的像素举例, 用蓝色像素的绝对位置索引与其他位置索引进行相减,就得到其他位置相对蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是( 0 , 1 ) ,则它相对蓝色像素的相对位置索引为( 0 , 0 ) − ( 0 , 1 ) = ( 0 , − 1 ) 。那么同理可以得到其他位置相对蓝色像素的相对位置索引矩阵。同样,也能得到相对黄色,红色以及绿色像素的相对位置索引矩阵。接下来将每个相对位置索引矩阵按行展平,并拼接在一起可以得到下面的4x4矩阵 。
上述计算得到的是相对位置索引,并不是相对位置偏置参数。因为后面我们会根据相对位置索引去取对应的参数,如下图所示。比如说黄色像素是在蓝色像素的右边,所以相对蓝色像素的相对位置索引为( 0 , − 1 )。绿色像素是在红色像素的右边,所以相对红色像素的相对位置索引为( 0 , − 1 ) 。可以发现这两者的相对位置索引都是( 0 , − 1 ),所以他们使用的相对位置偏置参数都是一样的。
在源码中作者为了方便把二维索引给转成了一维索引,但其原理并不是简单的进行加减。
1)首先在原始的相对位置索引上加上M-1(M为窗口的大小,在本示例中M=2),加上之后索引中就不会有负数了。
2)将所有的行标都乘上2M-1。
3)将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0+(-1)=(-1)+0的问题了。
在上述计算后,仍是得到的相对位置索引,并不是相对位置偏置,真正使用到的可训练参数B是保存在relative position bias table表里的,这个表的长度是等于(2M-1)×(2M-1)的。那么上述公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。
论文中给出了关于不同Swin Transformer的配置,T(Tiny),S(Small),B(Base),L(Large),其中:
1. 深度学习之图像分类(十三): Swin Transformer - 魔法学院小学弟
2. Swin-Transformer网络结构详解_swin transformer_太阳花的小绿豆的博客-CSDN博客