• 【Swin Transformer原理和源码解析】Hierarchical Vision Transformer using Shifted Windows


    前言

    ViT让Transformer第一次在视觉任务中暂露头角,而Swin Transfomer直接让Transformer在视觉任务中大放光彩,直接打败了当时的所有的CNN网络,一出来就直接是当时的Sota。现在的很多厉害的Transfomer变体都是Swin改进的,而且Swin Transformer这个网络在很多比赛上都会用它,分类、分割、检测基本上用它都不会差,我打的一个分类比赛就是用的它: 【记第一次kaggle比赛】PetFinder.my - Pawpularity Contest 宠物预测。当时打的时候是掉包的,两句话就创建了Model了,知其然不知所以然,这怎么行,所以今天有必要学习一下。

    论文地址: https://arxiv.org/pdf/2103.14030.pdf

    源码地址: https://github.com/microsoft/Swin-Transformer

    这里我用的是b站大佬 霹雳吧啦Wz 改编后(相对源码作了微小改动,增加了多尺度训练)的代码:
    WZMIAOMIAO

    注释版本代码也同样分享到了我的Github:https://github.com/HuKai97/Classification-Annotations

    一、动机和改进点

    VIT为了让图像可以像词向量那样输入Encoder中,而且计算量还不能太大,就直接将图像切分成一个个小的Patch,再把每个Patch当成一个词向量,把所有Patch拼接起来送入Encoder,这样当然可以降低参数量和计算量,但是当图像变大,Patch数目变多,复杂度太大。还有没有更好的输入方式了呢?

    VIT主要是改变了一下图片的输入,让Transformer的Encoder可以适用于图像任务中,但是对于整个模型的架构(之前讲LN提前了),VIT是没有做什么改进的,用的还是原始的Transformer中的Encoder(整个Encoder内部各个encoder变换,但是特征的shape是不变的)。那么原始的Transformer的Encoder模块真的就适用于图像任务吗,还有没有更好的Encoder结构?

    所以总结下,ViT有两个问题:

    1. 尺度问题,数据集物体大大小小,但是整个Encoder过程特征尺度是不变的,效果肯定不好;
    2. 划分patch,再把整张图片的所有patch都输入Encoder中,计算量太大;

    所以,Swin Transformer针对这两点做出了改进:

    1. Encode呈现金字塔形状。每过一个Encode图片shape变小,感受野在不停的增大,解决了尺度问题。
    2. 注意力机制放在一个窗口内部。不再把整张图片的所有patch都输入Encoder,而是将各个Patch单独的输入Encoder,解决了计算量太大的问题。

    二、整体架构:SwinTransformer

    在这里插入图片描述

    1. Patch Embeded:对输入图片 [bs,3,H_,W_] 进行处理。第一步:先经过Patch Partition,将图像划分为一个个的patch,每个patch是4x4x3大小(4x4Conv实现)得到一个 [bs,48,H_/4, W_/4] 大小的特征图;第二步:经过一个Linear Embedding层,进行Linear线性变换,得到 [bs, H_/4 * W_/4, C=96];(但是实际代码是通过一个4x4Conv s=4实现的,其实本质还是在学习参数,一样的)
    2. 经过4个stage:每个stage是若干个Swin Transformer Block + Patch Merging。前者计算相关性,后者进行采样,实现多尺度;最终经过4个stage后,特征下采样为 [bs,H_/32 * W_/32,8C=768];
    3. 分类:经过一个avgpool+flatten+Linear进行分类预测,最终得到 [bs,num_classes];

    源码:

    class SwinTransformer(nn.Module):
        r""" Swin Transformer
            A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
              https://arxiv.org/pdf/2103.14030
        """
        def __init__(self, patch_size=4, in_chans=3, num_classes=1000,
                     embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                     window_size=7, mlp_ratio=4., qkv_bias=True,
                     drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                     norm_layer=nn.LayerNorm, patch_norm=True,
                     use_checkpoint=False, **kwargs):
            """
            patch_size: 每个patch的大小  4x4
            in_chans: 输入图像的通道数 3
            num_classes: 分类类别数 默认1000
            embed_dim: 通过Linear Embedding后映射得到的通道数 也就是图片中的C 默认96
            depths: 每个stage中重复swin-transformer block的次数 默认(2, 2, 6, 2)
            num_heads: 每个stage中swin-transformer block的muti-head的个数 默认(3, 6, 12, 24)
            window_size: 滑动窗口的大小 默认7x7
            mlp_ratio: MLP中第一个全连接层Linear会将channel翻多少倍 默认4倍
            qkv_bias: 在muti-head self-attention中是否使用偏置 默认使用True
            drop_rate:
            attn_drop_rate: 在muti-head self-attention中使用的drop rate
            drop_path_rate: 在每个swin-transformer block中使用的drop rate  从0慢慢增加到0.1
            norm_layer: LN
            patch_norm:
            use_checkpoint: 使用可以节省内存 默认不使用
            """
            super().__init__()
    
            self.num_classes = num_classes   # 5
            self.num_layers = len(depths)    # 4
            self.embed_dim = embed_dim       # C = 96
            self.patch_norm = patch_norm     # True
            # stage4输出特征矩阵的channels
            self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))   # 768 = 8C
            self.mlp_ratio = mlp_ratio       # 4.0
    
            # split image into non-overlapping patches
            self.patch_embed = PatchEmbed(
                patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
                norm_layer=norm_layer if self.patch_norm else None)
            self.pos_drop = nn.Dropout(p=drop_rate)  # p=0
    
            # stochastic depth
            # [0.0, 0.00909090880304575, 0.0181818176060915, 0.027272727340459824, 0.036363635212183, 0.045454543083906174, 0.054545458406209946, 0.06363636255264282, 0.0727272778749466, 0.08181818574666977, 0.09090909361839294, 0.10000000149011612]
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
    
            # build layers/stages   4个
            self.layers = nn.ModuleList()
            for i_layer in range(self.num_layers):
                # 注意这里构建的stage和论文图中有些差异
                # 这里的stage不包含该stage的patch_merging层,包含的是下个stage的
                # stage1-3: Swin Transformer Block + Patch Merging
                # Stage4: Swin Transformer Block
                layers = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                    depth=depths[i_layer],
                                    num_heads=num_heads[i_layer],
                                    window_size=window_size,
                                    mlp_ratio=self.mlp_ratio,
                                    qkv_bias=qkv_bias,
                                    drop=drop_rate,
                                    attn_drop=attn_drop_rate,
                                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                    norm_layer=norm_layer,
                                    downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                    use_checkpoint=use_checkpoint)
                self.layers.append(layers)
    
            self.norm = norm_layer(self.num_features)   # LN(768)
            self.avgpool = nn.AdaptiveAvgPool1d(1)
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # 分类头  768 -> 5
    
            self.apply(self._init_weights)   # 初始化
    
        def _init_weights(self, m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
        def forward(self, x):
            """
            x: [bs, 3, H_, W_]
            """
            # 1、Patch Partition + Linear Embedding
            # [bs, 3, H_, W_] -> [bs, H_/4 * W_/4, C] -> [bs, H_/4 * W_/4, C]   C=96
            x, H, W = self.patch_embed(x)  # H = H_/4  W = W_/4
            x = self.pos_drop(x)
    
            # 2、4 stage = 4 x (Swin Transformer Block x n + Patch Merging)
            # x: [bs, H_/4 * W_/4, C] -> [bs, H_/8 * W_/8, 2C] -> [bs, H_/16 * W_/16, 4C] -> [bs, H_/32 * W_/32, 8C]
            for layer in self.layers:
                x, H, W = layer(x, H, W)
            
            # 3、分类
            x = self.norm(x)  # LN(8C=768)
            x = self.avgpool(x.transpose(1, 2))  # [bs, H_/32 * W_/32, 8C] -> [bs, 8C, H_/32 * W_/32] -> [bs, 8C, 1]
            x = torch.flatten(x, 1)  # [bs, 8C, 1] -> [bs, 8C]
            x = self.head(x)  # [bs, num_classes]
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104

    三、输入设置:PatchEmbed

    源码和论文有出入,这里直接使用一个4x4Conv s=4,实现下采样的过程。对输入图片 [bs,3,H_,W_]进行初步处理,得到一个[bs, H_/4 * W_/4, C=96]大小的特征图。源码如下:

    class PatchEmbed(nn.Module):
        """
        2D Image to Patch Embedding  [bs, 3, H_, W_] -> [B, H_/4 * W_/4, C=96]
        """
        def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
            """
            patch_size: 每个patch的大小 4x4
            in_c: 输入图像的channel 3
            embed_dim: 96 = C
            norm_layer: LN
            """
            super().__init__()
            patch_size = (patch_size, patch_size)
            self.patch_size = patch_size
            self.in_chans = in_c
            self.embed_dim = embed_dim
            self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)  # 4x4Conv 下采样4倍 c:3->96
            self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    
        def forward(self, x):
            # x: [bs, 3, H_, W_]
            _, _, H, W = x.shape
    
            # padding
            # 如果输入图片的H,W不是patch_size的整数倍,需要进行padding
            pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)  # False
            if pad_input:
                # to pad the last 3 dimensions,
                # (W_left, W_right, H_top,H_bottom, C_front, C_back)
                x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                              0, self.patch_size[0] - H % self.patch_size[0],
                              0, 0))
    
            # 1、Patch Partition
            # 下采样patch_size倍  [bs, 3, H_, W_] -> [bs, C=96, H_/4, W_/4]
            x = self.proj(x)
            _, _, H, W = x.shape  # H=H_/4  W=W_/4
            # flatten: [B, C, H_/4, W_/4] -> [B, C, H_/4 * W_/4]
            # transpose: [B, C, H_/4 * W_/4] -> [B, H_/4 * W_/4, C]
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            return x, H, W
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    四、4个重复的Stage:BasicLayer

    每个stage都由若干个Swin Transformer Block 和 1个Patch Merging组成。

    class BasicLayer(nn.Module):
        """A basic Swin Transformer layer for one stage."""
        def __init__(self, dim, depth, num_heads, window_size,
                     mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                     drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
            """
            dim: C = 96
            depth: 重叠的Swin Transformer Block个数
            num_heads: muti-head self-transformer的头数
            window_size: 窗口大小7x7
            mlp_ratio: MLP中第一个全连接层Linear会将channel翻多少倍 默认4倍
            qkv_bias: 在muti-head self-attention中是否使用偏置 默认使用True
            drop: patch_embed之后一般要接一个Dropout 但是默认是 0.0
            attn_drop: 在muti-head self-attention中使用的drop rate  0.0
            drop_path: list: depth  存放这个stage中depth个transformer block的drop rate
            norm_layer: LN
            downsample: Pathc Merging进行下采样
            use_checkpoint: Whether to use checkpointing to save memory. Default: False
            """
            super().__init__()
            self.dim = dim
            self.depth = depth
            self.window_size = window_size
            self.use_checkpoint = use_checkpoint
            self.shift_size = window_size // 2  # 3
    
            # 调用depth个swin transformer block
            self.blocks = nn.ModuleList([
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else self.shift_size,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer)
                for i in range(depth)])
    
            # patch merging layer
            if downsample is not None:
                self.downsample = downsample(dim=dim, norm_layer=norm_layer)
            else:
                self.downsample = None
    
        def create_mask(self, x, H, W):
            ...
    
        def forward(self, x, H, W):
            # 1、depth个swin transformer block
            # 因为每个stage中的特征图大小是不变的,所以每个block的mask大小是相同的 所以只需要创建一次即可
            # [64,49,49]  64个网格  49x49每个网格中的每个位置(49个位置)对该网格中所有位置(49个位置)的注意力蒙版
            attn_mask = self.create_mask(x, H, W)  # [nW, Mh*Mw, Mh*Mw]
            for blk in self.blocks:
                blk.H, blk.W = H, W
                if not torch.jit.is_scripting() and self.use_checkpoint:
                    x = checkpoint.checkpoint(blk, x, attn_mask)
                else:
                    # 默认执行 调用swin transformer block
                    x = blk(x, attn_mask)
    
            # 2、下采样 Patch Merging
            # 最后一个stage是None 不执行下采样
            if self.downsample is not None:
                x = self.downsample(x, H, W)
                H, W = (H + 1) // 2, (W + 1) // 2  # 下采样 重新计算H W
    
            return x, H, W
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70

    值得注意的是创建attention mask(create_mask)的步骤,这一步是下面SW-MSA和W-MSA的关键点,下面再详细讲解。

    4.1、SwinTransformerBlock

    4.1.1、创建mask

    在SwinTransformerBlock中,主要是负责创建attention mask,只在shift windows muti-head attention中使用,主要是告诉我们当前位置和哪些其他位置是同属于一个windows的(因为之前有一步shift window的操作),同属于一个windows的位置的mask=0,不同属于一个位置的mask=-100。

    这样到后面计算出attention之后,同一个windows位置的attention + mask再softmax值是不变的,但是不同windows位置的attention + mask(-100),再softmax值就趋近于0了。

    class BasicLayer(nn.Module):
        """A basic Swin Transformer layer for one stage."""
        ...
        def create_mask(self, x, H, W):
            """calculate attention mask for SW-MSA(shift window muti-head self-attention)
            以第一个stage为例
            x: [bs, 56x56, 96]
            H: 56
            W: 56
            返回attn_mask: [64,49,49] 64个网格  49x49每个网格中的每个位置(49个位置)对该网格中所有位置(49个位置)的注意力蒙版
                         记录每个位置需要在哪些位置计算attention
            """
            # 保证Hp和Wp是window_size的整数倍
            Hp = int(np.ceil(H / self.window_size)) * self.window_size   # 56
            Wp = int(np.ceil(W / self.window_size)) * self.window_size   # 56
            # 拥有和feature map一样的通道排列顺序,方便后续window_partition
            img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # [1, 56, 56, 1]
            # 对h和w先进行切片 划分为3个区域  0=(0,-7) (-7,-3) (-3,-1)
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            # 对3x3=9个区域进行划分 编号 0-8
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
    
            # 将img_mask划分为一个个的窗口   64个7x7大小的窗口
            # [1,56,56,1] -> [64,7,7,1] -> [64,7,7]
            mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
            # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1] -> [nW, Mh*Mw, Mh*Mw]=[64,49,49]
            # 数字相同的位置代表是同一个区域  我们就是要计算同一个区域的attention  相减之后为0的区域就是我们需要计算attention的地方
            # 64个网格  49x49每个网格中的每个位置(49个位置)对该网格中所有位置(49个位置)的注意力蒙版
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # 对于非零区域填上-100  这些区域是不需要计算attention的  所以在之后的softmax后就会为0
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            return attn_mask
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    这里涉及到划分窗口的操作:

    def window_partition(x, window_size: int):
        """
        将feature map按照window_size划分成一个个没有重叠的window
        Args:
            x: (B, H, W, C)
            window_size (int): window size(M)
    
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape   # 1  56  56  1
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)  # [1,56,56,1] -> [1,8,7,8,7,1]
        # permute: [B, H//Mh, Mh, W//Mw, Mw, C] -> [B, H//Mh, W//Mh, Mw, Mw, C]
        # view: [B, H//Mh, W//Mw, Mh, Mw, C] -> [B*num_windows, Mh, Mw, C]
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)  # [1,8,7,8,7,1] -> [1,8,8,7,7,1] -> [64,7,7,1]
        return windows
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    4.1.2、shift特征

    class SwinTransformerBlock(nn.Module):
    	def forward(self, x, attn_mask):
    		# cyclic shift
            if self.shift_size > 0:  # SW-MSA
                # 对x特征进行移动  0-shift_size列移动到最右侧   0-shift_size行移动到最下面
                # -的就是从上往下 从左往右  +的就是从下往上 从右往左了
                # 对应的attn_mask就是传入的attn_mask
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:  # W-MSA  不需要移动
                shifted_x = x
                attn_mask = None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    最后计算完SW-MSA后需要将shift过的特征进行还原:

    		# 之前shift过windows 再还原  从下往上 从右往左 +
            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
    
    • 1
    • 2
    • 3
    • 4
    • 5

    4.1.3、为shift后的特征划分窗口

    		# 为shifted_x划分窗口  与attn_mask划分的窗口对应  [bs,56,56,96] -> [512,7,7,96]  8x8xbs个7x7的窗口 x 96个通道
            x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
            x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]=[512,49,96]
    
    • 1
    • 2
    • 3

    这里的划分窗口和上面mask的划分窗口一样,就不赘述。

    4.1.4、W-MSA VS SW-MSA

    class WindowAttention(nn.Module):
        r"""W-MSA/SW-MSA
        Window based multi-head self attention (W-MSA) module with relative position bias.
        It supports both of shifted and non-shifted window.
        """
        def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
            """
            dim: C = 96
            window_size: 窗口大小7x7
            num_heads: muti-head self-transformer的头数
            qkv_bias: 在muti-head self-attention中是否使用偏置 默认使用True
            proj_drop: 在muti-head self-attention中使用的drop rate  0.0
            """
            super().__init__()
            self.dim = dim
            self.window_size = window_size  # [7, 7]
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = head_dim ** -0.5
    
            # 初始化relative_position_bias_table
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*7-1 * 2*7-1, num_heads]
    
            # 1、生成绝对位置坐标索引
            coords_h = torch.arange(self.window_size[0])  # tensor([0, 1, 2, 3, 4, 5, 6])
            coords_w = torch.arange(self.window_size[1])  # tensor([0, 1, 2, 3, 4, 5, 6])
            # coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
            # [2, 7, 7]  7x7窗口的xy坐标
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
            # [2, 7, 7] -> [2, 49]  第一个是所有位置的行坐标  第二个是所有位置的列坐标
            coords_flatten = torch.flatten(coords, 1)
    
            # 2、生成相对位置坐标索引
            # [2, Mh*Mw, 1] - [2, 1, Mh*Mw] -> [2, Mh*Mw, Mh*Mw]
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    
            # [2, Mh*Mw, Mh*Mw] -> [Mh*Mw, Mh*Mw, 2]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    
            # 3、将二元相对位置坐标索引转变成一元相对位置坐标索引
            # 原始相对位置行/列标 = -6~6 + (window_size-1) -> 0~12
            # 行标 + (2 * window_size - 1) -> 13~25
            # 这时直接把行标 + 列标 直接把2D索引转换为1D索引 就不会出现(-1,0) (0,-1) 相加都是-1 无法区分的情况了
            relative_coords[:, :, 0] += self.window_size[0] - 1      # 行标 + (window_size-1)
            relative_coords[:, :, 1] += self.window_size[1] - 1      # 列标 + (window_size-1)
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 行标 + (2 * window_size - 1)
            # [Mh*Mw, Mh*Mw, 2] -> [Mh*Mw, Mh*Mw]   行标 + 列标   直接转换为1元索引  与relative_position_bias_table一一对应
            relative_position_index = relative_coords.sum(-1)
    
            # 把relative_position_index放到缓存中  因为relative_position_index是固定值  不会变的  不需要修改
            # 我们网络训练的其实是relative_position_bias_table中的参数  我们每次循环都从relative_position_bias_table中拿对应idx的值即可
            self.register_buffer("relative_position_index", relative_position_index)
    
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # 生成qkv  3倍dim = q+k+v
            self.attn_drop = nn.Dropout(attn_drop)             # p=0.0
            self.proj = nn.Linear(dim, dim)                    # linear
            self.proj_drop = nn.Dropout(proj_drop)             # linear dropout p=0
    
            nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)  # 初始化relative_position_bias_table参数
    
            self.softmax = nn.Softmax(dim=-1)   # softmax层
    
        def forward(self, x, mask: Optional[torch.Tensor] = None):
            """
            x: [bsx8x8, 49, 96]  bsx  8x8个7x7大小的window size  x96channel
            mask: W-MSA和SW-MSA交替出现 None/[8x8,49,49]  记录8x8个7x7大小的window size  中 每个位置需要和哪些位置计算attention
                  =0的位置表示是需要计算attention的
            Attention(Q,K,V) = SoftMax(Q*K的转置/scale + B)*V
            """
            B_, N, C = x.shape  # batch_size*num_windows=bsx8x8, Mh*Mw=7x7, total_embed_dim=96
    
            # 生成qkv 和vit中的一样  和原始的transformer有区别  但是本质都是相同的 都是通过学习参数把输入的x映射到3个空间上
            # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
            # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
            # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] = [3,bsx8x8,3,7x7,32]
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    
            # 分别获得q k v
            # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head] = [bsx8x8,3,7x7,32]
            q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
    
            # 这里是先缩放再乘以k的转置  其实是一样的
            # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
            # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
    
            # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
            # 生成相对位置偏置:生成相对位置index + 去relative_position_bias_table中去取相应的可学习的bias参数
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
    
            # att + B
            attn = attn + relative_position_bias.unsqueeze(0)
    
            # softmax处理
            if mask is not None:
                # SW-MSA
                # mask: [nW, Mh*Mw, Mh*Mw]=[8x8,49,49]  记录8x8个7x7大小的window中每个位置需要和哪些位置计算attention
                #       =0的位置表示是需要计算attention的   不相同的区域位置是接近-100表示的
                nW = mask.shape[0]  # num_windows
                # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
                # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
                # 相同区域位置attn+0没有影响   不同区域位置attn+(-100)  再进行softmax   这个位置的attn就->0
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                # W-MSA
                attn = self.softmax(attn)
    
            attn = self.attn_drop(attn)
    
            # attn * v
            # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
            # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
            # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123

    这个步骤和ViT中的其实差不多,只不过ViT是计算每个位置和所有位置的attention,而WindowAttention是按照窗口来计算每个位置和当前windows内所有位置的attention,计算量更小。

    4.2、PatchMerging

    这部分主要功能就是进行下采样,操作:每个一个元素取一个像素,有点类似YOLOv5中的Focus层。最后将4个特征拼接起来,再接一个Linear缩放通道。

    在这里插入图片描述

    class PatchMerging(nn.Module):
        r""" Patch Merging Layer. 下采样
        输入[bs, H_/4 * W/4, C=96]  ->  输出[bs, H_/8 * W/8, 2C] 
        """
        def __init__(self, dim, norm_layer=nn.LayerNorm):
            super().__init__()
            self.dim = dim  # 输入特征的channel = 96/192/384
            self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
            self.norm = norm_layer(4 * dim)  # LN
    
        def forward(self, x, H, W):
            """
            x: [bs, H_/4 * W/4, C=96]
            """
            B, L, C = x.shape   # B=8 C=96 L= H_/4*W/4
            assert L == H * W, "input feature has wrong size"
    
            x = x.view(B, H, W, C)  # [bs, H_/4 * W/4, C=96] -> [bs, H_/4, W_/4, C=96]
    
            # padding
            # 如果输入feature map的H,W不是2的整数倍,需要进行padding
            pad_input = (H % 2 == 1) or (W % 2 == 1)  # False
            if pad_input:  # 跳过
                # to pad the last 3 dimensions, starting from the last dimension and moving forward.
                # (C_front, C_back, W_left, W_right, H_top, H_bottom)
                # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
                x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
    
            # 每隔一个像素取一个元素 有点像yolov5的focus层 最后一个特征 -> 4个下采样的特征
            # [bs, H_/4, W_/4, C=96]  ->  4 x [bs, H_/8, W_/8, C=96]
            x0 = x[:, 0::2, 0::2, :]  
            x1 = x[:, 1::2, 0::2, :] 
            x2 = x[:, 0::2, 1::2, :]  
            x3 = x[:, 1::2, 1::2, :]  
            
            # 4 x [bs, H_/8, W_/8, 96] -> [bs, H_/8, W_/8, 96*4] -> [bs, H_/8 * W_/8, 4*C]
            x = torch.cat([x0, x1, x2, x3], -1)  
            x = x.view(B, -1, 4 * C)  
    
            x = self.norm(x)  # LN
            
            # Linear 将通道从4C -> 2C  [bs, H_/8 * W_/8, C*4] -> [bs, H_/8 * W_/8, 2*C]
            x = self.reduction(x)  
    
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45

    五、总结

    为了解决ViT存在的问题:

    • 尺度问题:数据集物体大大小小,但是整个Encoder过程特征尺度是不变的,效果肯定不好;
    • 计算量大:划分patch,再把整张图片的所有patch都输入Encoder中,计算量太大;

    改进点:

    • Encode呈现金字塔形状。每过一个Stage对特征进行一次下采样,感受野在不停的增大,解决了尺度问题。所以Swin-Transformer不进适合分类任务,在下游检测、分割任务可以充分利用这种多尺度信息,检测效果很好;
    • 注意力机制放在一个窗口内部。不再把整张图片的所有patch都输入Encoder,而是将各个Patch单独的输入Encoder,解决了计算量太大的问题。

    关于第二点改进点还有很多的细节:

    • 提出Window Muti-head Self-Attention(W-MSA):把输入特征划分为一个个的windows窗口,只计算每个位置和当前windows窗口的所有位置的相关性Attention,其他窗口的不关心,这样就大大减少了计算量了;
    • W-MSA有一个问题,不同窗口完全不相关了,那不同窗口的位置之间不就没法交互了,所以作者又提出了Shift-Window Muti-head Self-Attention(SW-MSA)。
    • 特征图Shift操作其实很简单,就是特征某些行列平移,但是Shift了之后就会产生更多的窗口,计算量还是增加了,作者为了解决这个问题,引入了Mask,仍然是使用原先的窗口划分方式,但是用mask记录每个位置属于哪个窗口,相同窗口的位置mask=0,不同窗口的位置mask=-100,那么最后再用计算好的attention + mask,再softmax。于是,相同窗口的attention不变,不同窗口的attention=0,完美解决所有问题;
    • 作者还在WindowAttention中引入了relative_position_bias,使用Attention(Q,K,V) = SoftMax(Q*K的转置/scale + B)*V计算公式;

    六、一些问题

    6.1.为什么要W-MSA和SW-MSA混合使用?

    我的理解:单独的W-MSA和单独的SW-MSA其实都是固定的位置窗口(SW-MSA是对固定的区域进行shift,但是如果单独只使用SW-MSA,那么不还是固定的窗口),这样使用还是会有不同窗口无法信息交互的问题,但是混合起来使用,才能真正的起到交互作用。

    Reference

    b站: Swin Transformer论文精读【论文精读】

    b站: 12.1 Swin-Transformer网络结构详解

    b站: 12.2 使用Pytorch搭建Swin-Transformer网络

  • 相关阅读:
    企业如何搭建智能客服系统?
    如何在VSCode中带有参数的Debug(name、program、$file、args、pickArgs、指定虚拟环境)
    HSN:微调预训练ViT用于目标检测和语义分割,华南理工和阿里巴巴联合提出
    论文阅读(11) 操纵涡环来提高机动性 (2015)
    Dockerfile 制作常用命令总结
    蓝桥杯刷题|03普及-真题
    《三体》中的两大假说-射手与农场主
    Sass/Scss 入门不慌
    面试百问:如何提高自动化测试脚本稳定性
    Fast——Nginx
  • 原文地址:https://blog.csdn.net/qq_38253797/article/details/128133666