• VSSM VMamba实现


    VSSM

    Mamba实现可以参照之前的
    mamba_minimal系列
    论文地址:
    VMamba
    论文阅读
    VMamba:视觉状态空间模型
    代码地址:
    https://github.com/MzeroMiko/VMamba.git
    SS2D实现

    以分类任务用到的VMamba为例。

    维度变换

    操作的具体参数定义见初始化

    阶段维度
    输入x [ B , C , H , W ] [B, C, H, W] [B,C,H,W]
    embed [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
    阶段1 [ B , H / 4 , W / 4 , C 1 ] [B, H/4, W/4, C_1 ] [B,H/4,W/4,C1]
    阶段2 [ B , H / 8 , W / 8 , C 2 ] [B, H/8, W/8, C_2 ] [B,H/8,W/8,C2]
    阶段3 [ B , H / 16 , W / 16 , C 3 ] [B, H/16, W/16, C_3 ] [B,H/16,W/16,C3]
    阶段4 [ B , H / 32 , W / 32 , C 4 ] [B, H/32, W/32, C_4 ] [B,H/32,W/32,C4]
    分类器 [ B , 1000 ] [B, 1000 ] [B,1000]

    在这里插入图片描述

    初始化

    参数定义说明
    in_chans3输入图像的通道数
    depths[2, 2, 9, 2]定义每层的VSS Block数
    dims[96, 192, 384, 768]定义每层的输出通道数
    downsample_versionv2下采样操作的版本
    patchembed_versionv1图像嵌入
    mlp_ratio4.0定义mlp隐藏维度缩放
    ssm_d_state16ssm隐状态的维度
    ssm_ratio2.0d_inner = d_state * ssm_ratio
    ssm_initv0ssm初始化版本
    forward_typev2ssm前向版本

    模型参数初始化

    大部分参数即SS2D,VSS块中的参数由定义的ssm初始化版本初始化,剩下的线性层和归一化层参数由下面的函数初始化。

        def _init_weights(self, m: nn.Module):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    模型搭建

    def_make_layer

    构建VSSM的4个阶段,即4层,VSSBlock本身并不改变输入的尺寸,因此需要下采样模块将输出维度变换为下一阶段的输入维度

    def _make_layer(
            dim=96, 
            drop_path=[0.1, 0.1], 
            use_checkpoint=False, 
            norm_layer=nn.LayerNorm,
            downsample=nn.Identity(),
            # ===========================
            ssm_d_state=16,
            ssm_ratio=2.0,
            ssm_dt_rank="auto",       
            ssm_act_layer=nn.SiLU,
            ssm_conv=3,
            ssm_conv_bias=True,
            ssm_drop_rate=0.0, 
            ssm_init="v0",
            forward_type="v2",
            # ===========================
            mlp_ratio=4.0,
            mlp_act_layer=nn.GELU,
            mlp_drop_rate=0.0,
            **kwargs,
        ):
            depth = len(drop_path)
            blocks = []
            for d in range(depth):
                blocks.append(VSSBlock(
                    hidden_dim=dim, 
                    drop_path=drop_path[d],
                    norm_layer=norm_layer,
                    ssm_d_state=ssm_d_state,
                    ssm_ratio=ssm_ratio,
                    ssm_dt_rank=ssm_dt_rank,
                    ssm_act_layer=ssm_act_layer,
                    ssm_conv=ssm_conv,
                    ssm_conv_bias=ssm_conv_bias,
                    ssm_drop_rate=ssm_drop_rate,
                    ssm_init=ssm_init,
                    forward_type=forward_type,
                    mlp_ratio=mlp_ratio,
                    mlp_act_layer=mlp_act_layer,
                    mlp_drop_rate=mlp_drop_rate,
                    use_checkpoint=use_checkpoint,
                ))
            
            return nn.Sequential(OrderedDict(
                blocks=nn.Sequential(*blocks,),
                downsample=downsample,
            ))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    def _make_downsample

    默认下采样版本v2

    下采样模块,通过2D卷积之后,长宽变为原来的一半,通道数不变

        def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
            return nn.Sequential(
                Permute(0, 3, 1, 2),
                nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
                Permute(0, 2, 3, 1),
                norm_layer(out_dim),
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    patch embed

    默认嵌入版本v1,对输入图像进行embed

    输入x维度 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W],嵌入后通道维变为96, H = H p a t c h _ s i z e H = \frac{H}{patch\_size} H=patch_sizeH W = W p a t c h _ s i z e W = \frac{W}{patch\_size} W=patch_sizeW [ B , 96 , H 4 , W 4 ] [B, 96, \frac{H}{4}, \frac{W}{4}] [B,96,4H,4W]

     def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):
            return nn.Sequential(
                nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
                Permute(0, 2, 3, 1),
                (norm_layer(embed_dim) if patch_norm else nn.Identity()), 
            )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    第一至四阶段

    这几个阶段的差别在于每一层的VSSBlock数不同,由depths定义分别为 [2, 2, 9, 2],输出维度由dims定义分别为[96, 192, 384, 768]。其组成元素除一阶段外,均在VSSBlock前包含下采样模块以变换维度。

    具体介绍见VSSBlock

    分类器

    池化后长宽变为1,则变量尺寸变为 [ B , C , 1 , 1 ] [B, C, 1, 1] [B,C,1,1],展平后变为 [ B , C ] [B, C] [B,C]最后线性投影到类别维度1000

    [ B , 1000 ] [B, 1000] [B,1000]

    self.classifier = nn.Sequential(OrderedDict(
                norm=norm_layer(self.num_features), # B,H,W,C
                permute=Permute(0, 3, 1, 2),
                avgpool=nn.AdaptiveAvgPool2d(1),
                flatten=nn.Flatten(1),
                head=nn.Linear(self.num_features, num_classes),
            ))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    VSSBlock

    对于ssm分支来说,其输入输出维度不变为(B, H, W, d_model) ,对于mlp分支来说中间的隐藏维度根据mlp_ratio参数定义会有所增加,但是最后又会映射为原来的维度,因此整体上并不改变输入的维度。

    def __ init__

    主要分为两个分支ssm分支和mlp分支

    ssm分支

    主要组成部分是SS2D块
    SS2D实现

    if self.ssm_branch:
                self.norm = norm_layer(hidden_dim)
                self.op = _SS2D(
                    d_model=hidden_dim, 
                    d_state=ssm_d_state, 
                    ssm_ratio=ssm_ratio,
                    dt_rank=ssm_dt_rank,
                    act_layer=ssm_act_layer,
                    # ==========================
                    d_conv=ssm_conv,
                    conv_bias=ssm_conv_bias,
                    # ==========================
                    dropout=ssm_drop_rate,
                    # =========================
                    initialize=ssm_init,
                    forward_type=forward_type,
                )      
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    图中的SS2D和SS2D类的定义有偏差,简单来说是是包含SS2D块加一个残差连接,图中所示SS2D应表示状态空间模型SSM部分,即VSS块相比SS2D块只增加了残差连接和入口的归一化。如果定义了MLP分支,VSS块的输出还会经过一个残差连接的两层MLP

    在这里插入图片描述

    mlp分支
     if self.mlp_branch:
                self.norm2 = norm_layer(hidden_dim)
                mlp_hidden_dim = int(hidden_dim * mlp_ratio)
                self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False)
                
     class Mlp(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
    
            Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
            self.fc1 = Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)
    
        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    def forward

        def _forward(self, input: torch.Tensor):
            if self.ssm_branch:
                if self.post_norm:
                    x = input + self.drop_path(self.norm(self.op(input)))
                else:
                    x = input + self.drop_path(self.op(self.norm(input)))
            if self.mlp_branch:
                if self.post_norm:
                    x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
                else:
                    x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
            return x
    
         
        
        def forward(self, input: torch.Tensor):
            if self.use_checkpoint:
                return checkpoint.checkpoint(self._forward, input)
            else:
                return self._forward(input)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
  • 相关阅读:
    (数论) 扩展gcd
    MySQL大数据量查询方案
    使用 PHP 和 MySQL 的投票和投票系统
    UG\NX二次开发 选择基准平面 UF_UI_select_with_single_dialog
    密码学奇妙之旅、01 CFB密文反馈模式、AES标准、Golang代码
    Spring MVC的执行流程
    【网络是怎么连接的】第二章(中):一个网络包的发出
    Arduino追光小车
    Linux内核7. 内存管理
    【C++】构造函数和析构函数
  • 原文地址:https://blog.csdn.net/weixin_45668967/article/details/136724764