• Swin Transformer代码实现部分细节重点


    swin transformer

    1.patch-merging部分
    在这里插入图片描述
    代码:【amazing】

    		x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 1 的位置
            x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]  对应图片所有 3 的位置
            x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 2 的位置
            x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]  对应图片所有 4 的位置
            x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C] 拼在一起,通道变为4倍
    
    		x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]
            x = self.norm(x)
            x = self.reduction(x)  # [B, H/2*W/2, 2*C]  self.reduction = nn.Linear(4*dim, 2*dim, bias=False)一个线性映射使通道变为2倍
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.create mask部分(有点懵)
    ![在这里插入图片描述](https://img-blog.csdnimg.cn/ebc36327a9b84806b96d6d50c9f12dcd.png在这里插入图片描述
    划分窗口
    在这里插入图片描述

    相同的数字是连续的区域
    代码:

    		h_slices = (slice(0, -self.window_size), #切片 [0,-3) 正着数是从第一个开始记为0,倒着数从最后一个开始记为-1
                        slice(-self.window_size, -self.shift_size),# [-3,-1)
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices: # 给区域标号
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
        # 划分window窗口
            mask_windows = window_partition(img_mask, self.window_size)  # [nW, Mh, Mw, 1]窗口个数,窗口宽,高,通道数
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # [nW, Mh*Mw]
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, 1, Mh*Mw] - [nW, Mh*Mw, 1]
            # [nW, Mh*Mw, Mh*Mw] 利用广播机制
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            return attn_mask
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    3.window attention
    相对位置编码
    整体流程(摘自博客)
    在这里插入图片描述
    增加维度,下图图示【以下这些维度操作,就很amazing!!!】
    在这里插入图片描述

    利用广播机制相减,得到相对位置编码(摘自B导视频)
    如下图中颜色对应的坐标相减
    在这里插入图片描述
    这是permute变换前后变化,从横纵坐标分离 到 横纵坐标和在一起
    在这里插入图片描述

    代码:

     # 相对位置编码
            # get pair-wise relative position index for each token inside the window
            #首先 生成绝对位置索引
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])   # 生成网格坐标索引    堆叠
            coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
            coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw] 并展开为2D向量
            # coords_flatten[:, None, :] 在一维处插入新维度  , coords_flatten[:, :, None] 在二维处插入新维度
                                        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]  利用广播机制 就是通过相减得到他们的相对位置关系
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2] 调换位置
            #把二元索引变成一元索引
            relative_coords[:, :, 0] += self.window_size[0] - 1  # 坐标转换为从0开始
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 #行坐标乘(2M-1)
            relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw] 最后一个维度求和
            self.register_buffer("relative_position_index", relative_position_index) #注册为不参与网络学习的变量,
                                                        # #作用是根据最终的相对位置索引 找到对应的可学习的相对位置编码
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
  • 相关阅读:
    API接口大全分享,含短信API、IP查询API。。。
    redisTemplate注入失败,异常为null值,原因探究
    Linux操作系统的基础指令
    设计模式(二)-创建者模式(2-0)-简单工厂模式
    TT time tunnel时空隧道命令使用场景
    [计算机提升] 用户和用户组
    【BOOST C++ 12 函数式编程】(4) Boost.Ref
    【八大排序算法】插入排序、希尔排序、选择排序、堆排序、冒泡排序、快速排序、归并排序、计数排序
    Java多线程基础
    【每周一测】Java阶段三第三周学习
  • 原文地址:https://blog.csdn.net/weixin_44040169/article/details/126911018