    [Image Classification] 2022 MPViT (CVPR)

    Paper: https://arxiv.org/abs/2112.11010

    Code: https://github.com/youngwanLEE/MPViT

    Slide-style summary: https://blog.csdn.net/Qingkaii/article/details/124398735

    1. Introduction

    1.1 Overview

    • In this work, the authors take a perspective different from existing Transformers and explore multi-scale patch embedding combined with a multi-path structure, proposing the Multi-Path Vision Transformer (MPViT).

    • Using overlapping convolutional patch embedding, MPViT simultaneously embeds patches of different scales into features of the same size (i.e. the same sequence length). Tokens of different scales are then fed independently through multiple paths into Transformer encoders, and the resulting features are aggregated, so that fine and coarse feature representations coexist at the same feature level.

    • In the feature aggregation step, a global-to-local feature interaction (GLI) process is introduced, which concatenates the convolutional local features with the Transformer's global features, exploiting both the local connectivity of convolutions and the global context of Transformers.

    In short, the authors focus on multi-scale, multi-path processing: by patchifying the image at several scales and organizing the resulting tokens into a multi-path structure, they improve the accuracy of Transformers on dense prediction tasks such as segmentation.
    [Figure: overview of MPViT's multi-scale patch embedding and multi-path structure]

    1.2 Contributions

    • A multi-scale embedding method with a multi-path structure is proposed to represent fine and coarse features simultaneously for dense prediction tasks.
    • Global-to-local feature interaction (GLI) is introduced, which exploits both the local connectivity of convolutions and the global context of Transformers for feature representation.
    • MPViT outperforms state-of-the-art ViTs while using fewer parameters and FLOPs.

    2. Network

    2.1 Overall architecture

    • The input image first passes through a convolutional stem that extracts features,
    • followed by four Transformer stages, shown in the left column of the figure;
    • the middle column expands the two sub-blocks that make up each stage,
    • and the right column illustrates the Transformer (with the local convolution) inside the multi-path module as well as the global-to-local interaction module.

    [Figure: MPViT overall architecture]

    2.2 Conv-stem

    This module consists of two 3×3 convolutions; it extracts features and reduces the spatial size of the image without losing salient information.

    Input image size: H×W×3

    Two convolution layers: two 3×3 convolutions with C2/2 and C2 channels respectively, each with stride 2.

    Output: the resulting feature map has size H/4×W/4×C2, where C2 is the channel size of stage 2.

    Notes:

    1. Each convolution is followed by Batch Normalization and a Hardswish activation (a minimal sketch of the stem follows these notes).

    2. From stage 2 to stage 5, the proposed Multi-scale Patch Embedding (MS-PatchEmbed) and Multi-path Transformer (MP-Transformer) blocks are stacked within each stage.
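
    A minimal sketch of this conv-stem in PyTorch, assuming the layout described above (3×3 conv → BN → Hardswish, stride 2 twice); it mirrors the stem built from `Conv2d_BN` in the code of Section 3:

    import torch
    from torch import nn

    def conv_bn_act(in_ch, out_ch, stride):
        # 3x3 conv -> BatchNorm -> Hardswish; padding 1 so only the stride changes the grid size
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.Hardswish(),
        )

    c2 = 64  # channel size of stage 2 (example value)
    stem = nn.Sequential(
        conv_bn_act(3, c2 // 2, stride=2),   # H x W x 3   -> H/2 x W/2 x C2/2
        conv_bn_act(c2 // 2, c2, stride=2),  # H/2 x W/2   -> H/4 x W/4 x C2
    )

    x = torch.randn(1, 3, 224, 224)
    print(stem(x).shape)  # torch.Size([1, 64, 56, 56])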

    2.3 Multi-Scale Patch Embedding

    The structure of the multi-scale patch embedding is shown below. For the input feature map, convolution kernels of different sizes are used to obtain feature information at different scales (this is how the paper describes it, although in the released code every kernel is actually 3×3).

    To keep the parameter count low, stacked 3×3 kernels are used to grow the receptive field to that of 5×5 and 7×7 kernels, and depthwise separable convolutions are used to reduce parameters further.

    [Figure: Multi-Scale Patch Embedding structure]

    Input:

    The input X of stage i is passed through a k×k 2D convolution with stride s and padding p.

    The height and width of the output token map F are:

    $$H_i=\left\lfloor \frac{H_{i-1}-k+2p}{s}\right\rfloor+1,\quad W_i=\left\lfloor \frac{W_{i-1}-k+2p}{s}\right\rfloor+1$$

    By adjusting the stride and padding, the token sequence length can be tuned, so different patch sizes can yield outputs of the same size.

    Parallel convolutional patch embedding layers with different kernel sizes are therefore built: the sequence length stays the same while the patch size can be 3×3, 5×5, or 7×7.

    For example, as shown in Figure 1, vision tokens of the same sequence length but different patch sizes (3×3, 5×5, 7×7) can be generated.
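
    A quick numeric check of the size formula above (a standalone sketch, not taken from the paper): with stride 2 and padding k//2, kernels of size 3, 5, and 7 all produce the same token-map size.

    import math

    def out_size(h, k, s, p):
        # floor((h - k + 2p) / s) + 1, the standard convolution output-size formula
        return math.floor((h - k + 2 * p) / s) + 1

    h_prev = 56  # e.g. the H/4 map from the conv-stem for a 224x224 input
    for k in (3, 5, 7):
        p = k // 2  # "same"-style padding
        print(k, out_size(h_prev, k, s=2, p=p))  # 3 -> 28, 5 -> 28, 7 -> 28: identical sequence length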

    In practice

    • Since stacking convolutions of the same size enlarges the receptive field with fewer parameters,
      two consecutive 3×3 layers are used to build a 5×5 receptive field and three 3×3 layers to build a 7×7 receptive field.
    • For the triple-path structure, three consecutive 3×3 convolutions are used with channel size C', padding 1 and stride s, where s is 2 when the spatial resolution is reduced and 1 otherwise.
      Thus, given the conv-stem output $X_i$, MS-PatchEmbed produces features $F_{3\times 3}(X_i), F_{5\times 5}(X_i), F_{7\times 7}(X_i)$ of the same size $\frac{H_i}{s}\times\frac{W_i}{s}\times C'$.
    • To reduce model parameters and computation, 3×3 depthwise separable convolutions are used, i.e. a 3×3 depthwise convolution followed by a 1×1 pointwise convolution.
    • Each convolution is followed by Batch Normalization and a Hardswish activation.

    The token embedding features of the different scales are then fed into separate Transformer encoders, one per path (a simplified sketch follows the figure below).

    [Figure: multi-scale patch embedding feeding the multi-path Transformer]
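
    A minimal sketch of one multi-scale embedding stage under the assumptions above (stacked 3×3 depthwise separable convolutions, only the first one strided). It is a simplified version of the `Patch_Embed_stage` / `DWCPatchEmbed` classes in Section 3:

    import torch
    from torch import nn

    def dw_sep_conv(ch, stride):
        # 3x3 depthwise conv + 1x1 pointwise conv + BN + Hardswish
        return nn.Sequential(
            nn.Conv2d(ch, ch, 3, stride, 1, groups=ch, bias=False),
            nn.Conv2d(ch, ch, 1, bias=False),
            nn.BatchNorm2d(ch),
            nn.Hardswish(),
        )

    class MultiScaleEmbed(nn.Module):
        """Three paths built from 1, 2 and 3 stacked 3x3 convs (3x3 / 5x5 / 7x7 receptive fields)."""
        def __init__(self, ch, downsample=True):
            super().__init__()
            # only the first conv is strided, so every path shares the same output size
            self.convs = nn.ModuleList([dw_sep_conv(ch, 2 if downsample else 1)]
                                       + [dw_sep_conv(ch, 1) for _ in range(2)])

        def forward(self, x):
            feats = []
            for conv in self.convs:  # features accumulate: path i reuses the convs of path i-1
                x = conv(x)
                feats.append(x)
            return feats             # F_3x3, F_5x5, F_7x7, all of size H/s x W/s x C'

    x = torch.randn(1, 64, 56, 56)
    print([f.shape for f in MultiScaleEmbed(64)(x)])  # three tensors of shape [1, 64, 28, 28]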

    2.4 Multi-path Transformer

    Motivation:

    Self-attention in Transformers can capture long-range dependencies (i.e. global context), but it is likely to overlook the structural information and local relationships inside each patch.

    Conversely, CNNs exploit local connectivity under translation invariance, which makes them rely more on texture than on shape when classifying visual objects.

    MPViT therefore combines CNNs and Transformers in a complementary way.

    Composition: the multi-path Transformers and the local-feature convolution below, and the Global-to-Local Feature Interaction on top.

    After the multi-path features go through self-attention (together with the local convolution) and the global context interaction, all features are concatenated, passed through an activation, and handed to the next stage.

    [Figure: Multi-path Transformer block]

    2.4.1 Multi-path Transformer and local-feature convolution

    Transformers can attend to correlations over long distances, whereas convolutional networks are better at extracting local context from an image; the authors therefore combine these two complementary operations in this block.

    Transformer

    Since self-attention is applied within every patch path and there are multiple paths, the authors reduce the computational cost by adopting the efficient factorized self-attention proposed in CoaT, which lowers the complexity to linear in the number of tokens:

    $$\operatorname{FactorAtt}(Q, K, V)=\frac{Q}{\sqrt{C}}\left(\operatorname{softmax}(K)^{\top} V\right)$$

    CNN

    To represent the local feature L, a depthwise residual bottleneck block is used, consisting of a 1×1 convolution, a 3×3 depthwise convolution, a second 1×1 convolution, and a residual connection. The convolution branch to the left of the three Transformer paths injects local image context into the model through the locality of convolution; this extra context compensates for the Transformer's weaker grasp of local semantics (a sketch follows the figure below).

    [Figure: depthwise residual bottleneck block for the local feature]
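
    A minimal sketch of this local-feature block under the description above (1×1 → 3×3 depthwise → 1×1 with a residual connection); the corresponding class in Section 3 is `ResBlock`:

    import torch
    from torch import nn

    class LocalFeatureBlock(nn.Module):
        """Depthwise residual bottleneck: 1x1 conv -> 3x3 depthwise conv -> 1x1 conv, plus identity."""
        def __init__(self, ch):
            super().__init__()
            self.body = nn.Sequential(
                nn.Conv2d(ch, ch, 1, bias=False), nn.BatchNorm2d(ch), nn.Hardswish(),
                nn.Conv2d(ch, ch, 3, 1, 1, groups=ch, bias=False), nn.BatchNorm2d(ch), nn.Hardswish(),
                nn.Conv2d(ch, ch, 1, bias=False), nn.BatchNorm2d(ch),
            )

        def forward(self, x):
            return x + self.body(x)  # the residual keeps the local feature at the same shape

    print(LocalFeatureBlock(64)(torch.randn(1, 64, 28, 28)).shape)  # torch.Size([1, 64, 28, 28])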

    In the original attention computation, the space complexity is $O(N^2)$ and the time complexity is $O(N^2C)$:

    $$\operatorname{Attn}(X)=\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{C}}\right) V$$

    • One query against n key-value pairs: the query takes an inner product with every key, producing n similarity scores, which softmax turns into n non-negative weights summing to 1.
    • The weight of each value in the output is the similarity between the query and the corresponding key,
      usually computed as an inner product, which measures how much each key influences each query.

    • Multiplying the softmax weights with the value matrix V gives the attention output.

      N and C denote the number of tokens and the embedding dimension, respectively.
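
    For reference, standard scaled dot-product attention materializes the N×N attention map explicitly (a generic sketch, not MPViT-specific):

    import torch

    def standard_attention(q, k, v):
        # q, k, v: [B, N, C]; the explicit N x N map costs O(N^2) memory and O(N^2 C) time
        attn = (q @ k.transpose(1, 2)) / q.shape[-1] ** 0.5  # [B, N, N]
        return attn.softmax(dim=-1) @ v                      # [B, N, C]

    q, k, v = (torch.randn(2, 196, 64) for _ in range(3))
    print(standard_attention(q, k, v).shape)  # torch.Size([2, 196, 64])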

    Factorized attention mechanism: space complexity $O(NC)$, time complexity $O(NC^2)$, i.e. a fraction $\frac{C}{N}$ of the original.

    $$\operatorname{FactorAtt}(Q, K, V)=\frac{Q}{\sqrt{C}}\left(\operatorname{softmax}(K)^{\top} V\right)$$

    To lower the complexity, similarly to LambdaNet (an attention factorization using an identity function and a softmax), attention is rewritten in the following form:

    • The softmax attention map is approximated by factorizing it with two functions and computing the second matrix product (keys with values) first:
    • The scaling factor $\frac{1}{\sqrt{C}}$ is added back for normalization, which gives better performance.

      $$\operatorname{FactorAtt}(X)=\phi(Q)\left(\psi(K)^{\top} V\right)$$

    On the other hand, original attention has a clear interpretation: it measures the similarity between the current position and every other position.

    Factorized attention is not as easy to interpret, and the explicit query-key inner product is lost.

    Although FactorAtt is not a direct approximation of standard attention, it is still a generalized attention mechanism with queries, keys and values (a sketch follows below).
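
    A minimal PyTorch sketch of the factorized attention above (single head, no convolutional relative position encoding); the full multi-head version used by MPViT is `FactorAtt_ConvRelPosEnc` in Section 3:

    import torch

    def factor_att(q, k, v):
        # q, k, v: [B, N, C]. Softmax over the token axis of K, then K^T V is only C x C,
        # so memory is O(NC) and time is O(N C^2) instead of O(N^2 C).
        k_softmax = k.softmax(dim=1)               # psi(K): softmax over N
        context = k_softmax.transpose(1, 2) @ v    # [B, C, C]
        return (q / q.shape[-1] ** 0.5) @ context  # phi(Q) = Q / sqrt(C), output [B, N, C]

    q, k, v = (torch.randn(2, 196, 64) for _ in range(3))
    print(factor_att(q, k, v).shape)  # torch.Size([2, 196, 64])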

    2.4.2 Global-to-Local Feature Interaction

    Purpose

    Aggregate the local and global features; this is implemented by concatenation.

    The input features are concatenated and passed through a 1×1 convolution (H(·) is a learned function that models the feature interaction). Because this module receives both the long-range-attending Transformer paths and the local-context convolution, it can be viewed as fusing the global and local semantics extracted in the current stage, making full use of the image information (see the sketch after the figure below).

    [Figure: Global-to-Local Feature Interaction]
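
    A minimal sketch of this aggregation under the assumptions above (channel-wise concatenation followed by a 1×1 convolution with BN and Hardswish); it corresponds to the `aggregate` layer of `MHCA_stage` in Section 3:

    import torch
    from torch import nn

    class GlobalToLocalInteraction(nn.Module):
        """H(.) = 1x1 conv over the channel-wise concat of the local and the per-path global features."""
        def __init__(self, ch, num_path, out_ch):
            super().__init__()
            self.aggregate = nn.Sequential(
                nn.Conv2d(ch * (num_path + 1), out_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.Hardswish(),
            )

        def forward(self, local_feat, global_feats):
            # local_feat: [B, C, H, W]; global_feats: list of [B, C, H, W], one per Transformer path
            return self.aggregate(torch.cat([local_feat] + global_feats, dim=1))

    gli = GlobalToLocalInteraction(ch=64, num_path=3, out_ch=128)
    local = torch.randn(1, 64, 28, 28)
    paths = [torch.randn(1, 64, 28, 28) for _ in range(3)]
    print(gli(local, paths).shape)  # torch.Size([1, 128, 28, 28])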

    2.5 Ablation study

    [Figure: ablation study results]

    3. Code

    # --------------------------------------------------------------------------------
    # MPViT: Multi-Path Vision Transformer for Dense Prediction
    # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
    # All Rights Reserved.
    # Written by Youngwan Lee
    # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
    # LICENSE file in the root directory of this source tree.
    # --------------------------------------------------------------------------------
    # References:
    # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
    # CoaT: https://github.com/mlpc-ucsd/CoaT
    # --------------------------------------------------------------------------------
    
    
    import math
    from functools import partial
    
    import numpy as np
    import torch
    from einops import rearrange
    from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    from timm.models.layers import DropPath, trunc_normal_
    from timm.models.registry import register_model
    from torch import einsum, nn
    
    __all__ = [
        "mpvit_tiny",
        "mpvit_xsmall",
        "mpvit_small",
        "mpvit_base",
    ]
    
    
    def _cfg_mpvit(url="", **kwargs):
        """configuration of mpvit."""
        return {
            "url": url,
            "num_classes": 1000,
            "input_size": (3, 224, 224),
            "pool_size": None,
            "crop_pct": 0.9,
            "interpolation": "bicubic",
            "mean": IMAGENET_DEFAULT_MEAN,
            "std": IMAGENET_DEFAULT_STD,
            "first_conv": "patch_embed.proj",
            "classifier": "head",
            **kwargs,
        }
    
    
    class Mlp(nn.Module):
        """Feed-forward network (FFN, a.k.a.
    
        MLP) class.
        """
        def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            drop=0.0,
        ):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)
    
        def forward(self, x):
            """foward function"""
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x
    
    
    class Conv2d_BN(nn.Module):
        """Convolution with BN module."""
        def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            pad=0,
            dilation=1,
            groups=1,
            bn_weight_init=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=None,
        ):
            super().__init__()
    
            self.conv = torch.nn.Conv2d(in_ch,
                                        out_ch,
                                        kernel_size,
                                        stride,
                                        pad,
                                        dilation,
                                        groups,
                                        bias=False)
            self.bn = norm_layer(out_ch)
            torch.nn.init.constant_(self.bn.weight, bn_weight_init)
            torch.nn.init.constant_(self.bn.bias, 0)
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    # Note that there is no bias due to BN
                    fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
    
            self.act_layer = act_layer() if act_layer is not None else nn.Identity(
            )
    
        def forward(self, x):
            """foward function"""
            x = self.conv(x)
            x = self.bn(x)
            x = self.act_layer(x)
    
            return x
    
    
    class DWConv2d_BN(nn.Module):
        """Depthwise Separable Convolution with BN module."""
        def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.Hardswish,
            bn_weight_init=1,
        ):
            super().__init__()
    
            # dw
            self.dwconv = nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size,
                stride,
                (kernel_size - 1) // 2,
                groups=out_ch,
                bias=False,
            )
            # pw-linear
            self.pwconv = nn.Conv2d(out_ch, out_ch, 1, 1, 0, bias=False)
            self.bn = norm_layer(out_ch)
            self.act = act_layer() if act_layer is not None else nn.Identity()
    
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2.0 / n))
                    if m.bias is not None:
                        m.bias.data.zero_()
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(bn_weight_init)
                    m.bias.data.zero_()
    
        def forward(self, x):
            """
            forward function
            """
            x = self.dwconv(x)
            x = self.pwconv(x)
            x = self.bn(x)
            x = self.act(x)
    
            return x
    
    
    class DWCPatchEmbed(nn.Module):
        """Depthwise Convolutional Patch Embedding layer Image to Patch
        Embedding."""
        def __init__(self,
                     in_chans=3,
                     embed_dim=768,
                     patch_size=16,
                     stride=1,
                     act_layer=nn.Hardswish):
            super().__init__()
    
            self.patch_conv = DWConv2d_BN(
                in_chans,
                embed_dim,
                kernel_size=patch_size,
                stride=stride,
                act_layer=act_layer,
            )
    
        def forward(self, x):
            """foward function"""
            x = self.patch_conv(x)
    
            return x
    
    
    class Patch_Embed_stage(nn.Module):
        """Depthwise Convolutional Patch Embedding stage comprised of
        `DWCPatchEmbed` layers."""
        def __init__(self, embed_dim, num_path=4, isPool=False):
            super(Patch_Embed_stage, self).__init__()
    
            self.patch_embeds = nn.ModuleList([
                DWCPatchEmbed(
                    in_chans=embed_dim,
                    embed_dim=embed_dim,
                    patch_size=3,
                    stride=2 if isPool and idx == 0 else 1,
                ) for idx in range(num_path)
            ])
    
        def forward(self, x):
            """foward function"""
            att_inputs = []
            for pe in self.patch_embeds:
                x = pe(x)
                att_inputs.append(x)
    
            return att_inputs
    
    
    class ConvPosEnc(nn.Module):
        """Convolutional Position Encoding.
    
        Note: This module is similar to the conditional position encoding in CPVT.
        """
        def __init__(self, dim, k=3):
            """init function"""
            super(ConvPosEnc, self).__init__()
    
            self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
    
        def forward(self, x, size):
            """foward function"""
            B, N, C = x.shape
            H, W = size
    
            feat = x.transpose(1, 2).view(B, C, H, W)
            x = self.proj(feat) + feat
            x = x.flatten(2).transpose(1, 2)
    
            return x
    
    
    class ConvRelPosEnc(nn.Module):
        """Convolutional relative position encoding."""
        def __init__(self, Ch, h, window):
            """Initialization.
    
            Ch: Channels per head.
            h: Number of heads.
            window: Window size(s) in convolutional relative positional encoding.
                    It can have two forms:
                    1. An integer of window size, which assigns all attention heads
                       with the same window size in ConvRelPosEnc.
                    2. A dict mapping window size to #attention head splits
                       (e.g. {window size 1: #attention head split 1, window size
                                          2: #attention head split 2})
                       It will apply different window size to
                       the attention head splits.
            """
            super().__init__()
    
            if isinstance(window, int):
                # Set the same window size for all attention heads.
                window = {window: h}
                self.window = window
            elif isinstance(window, dict):
                self.window = window
            else:
                raise ValueError()
    
            self.conv_list = nn.ModuleList()
            self.head_splits = []
            for cur_window, cur_head_split in window.items():
                dilation = 1  # Use dilation=1 at default.
                padding_size = (cur_window + (cur_window - 1) *
                                (dilation - 1)) // 2
                cur_conv = nn.Conv2d(
                    cur_head_split * Ch,
                    cur_head_split * Ch,
                    kernel_size=(cur_window, cur_window),
                    padding=(padding_size, padding_size),
                    dilation=(dilation, dilation),
                    groups=cur_head_split * Ch,
                )
                self.conv_list.append(cur_conv)
                self.head_splits.append(cur_head_split)
            self.channel_splits = [x * Ch for x in self.head_splits]
    
        def forward(self, q, v, size):
            """foward function"""
            B, h, N, Ch = q.shape
            H, W = size
    
            # We don't use CLS_TOKEN
            q_img = q
            v_img = v
    
            # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
            v_img = rearrange(v_img, "B h (H W) Ch -> B (h Ch) H W", H=H, W=W)
            # Split according to channels.
            v_img_list = torch.split(v_img, self.channel_splits, dim=1)
            conv_v_img_list = [
                conv(x) for conv, x in zip(self.conv_list, v_img_list)
            ]
            conv_v_img = torch.cat(conv_v_img_list, dim=1)
            # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
            conv_v_img = rearrange(conv_v_img, "B (h Ch) H W -> B h (H W) Ch", h=h)
    
            EV_hat_img = q_img * conv_v_img
            EV_hat = EV_hat_img
            return EV_hat
    
    
    class FactorAtt_ConvRelPosEnc(nn.Module):
        """Factorized attention with convolutional relative position encoding
        class."""
        def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_scale=None,
            attn_drop=0.0,
            proj_drop=0.0,
            shared_crpe=None,
        ):
            super().__init__()
            self.num_heads = num_heads
            head_dim = dim // num_heads
            self.scale = qk_scale or head_dim**-0.5
    
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
    
            # Shared convolutional relative position encoding.
            self.crpe = shared_crpe
    
        def forward(self, x, size):
            """foward function"""
            B, N, C = x.shape
    
            # Generate Q, K, V.
            qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads,
                                       C // self.num_heads).permute(2, 0, 3, 1, 4))
            q, k, v = qkv[0], qkv[1], qkv[2]
    
            # Factorized attention.
            k_softmax = k.softmax(dim=2)
            k_softmax_T_dot_v = einsum("b h n k, b h n v -> b h k v", k_softmax, v)
            factor_att = einsum("b h n k, b h k v -> b h n v", q,
                                k_softmax_T_dot_v)
    
            # Convolutional relative position encoding.
            crpe = self.crpe(q, v, size=size)
    
            # Merge and reshape.
            x = self.scale * factor_att + crpe
            x = x.transpose(1, 2).reshape(B, N, C)
    
            # Output projection.
            x = self.proj(x)
            x = self.proj_drop(x)
    
            return x
    
    
    class MHCABlock(nn.Module):
        """Multi-Head Convolutional self-Attention block."""
        def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=3,
            drop_path=0.0,
            qkv_bias=True,
            qk_scale=None,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            shared_cpe=None,
            shared_crpe=None,
        ):
            super().__init__()
    
            self.cpe = shared_cpe
            self.crpe = shared_crpe
            self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                shared_crpe=shared_crpe,
            )
            self.mlp = Mlp(in_features=dim, hidden_features=dim * mlp_ratio)
            self.drop_path = DropPath(
                drop_path) if drop_path > 0.0 else nn.Identity()
    
            self.norm1 = norm_layer(dim)
            self.norm2 = norm_layer(dim)
    
        def forward(self, x, size):
            """foward function"""
            if self.cpe is not None:
                x = self.cpe(x, size)
            cur = self.norm1(x)
            x = x + self.drop_path(self.factoratt_crpe(cur, size))
    
            cur = self.norm2(x)
            x = x + self.drop_path(self.mlp(cur))
            return x
    
    
    class MHCAEncoder(nn.Module):
        """Multi-Head Convolutional self-Attention Encoder comprised of `MHCA`
        blocks."""
        def __init__(
            self,
            dim,
            num_layers=1,
            num_heads=8,
            mlp_ratio=3,
            drop_path_list=[],
            qk_scale=None,
            crpe_window={
                3: 2,
                5: 3,
                7: 3
            },
        ):
            super().__init__()
    
            self.num_layers = num_layers
            self.cpe = ConvPosEnc(dim, k=3)
            self.crpe = ConvRelPosEnc(Ch=dim // num_heads,
                                      h=num_heads,
                                      window=crpe_window)
            self.MHCA_layers = nn.ModuleList([
                MHCABlock(
                    dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    drop_path=drop_path_list[idx],
                    qk_scale=qk_scale,
                    shared_cpe=self.cpe,
                    shared_crpe=self.crpe,
                ) for idx in range(self.num_layers)
            ])
    
        def forward(self, x, size):
            """foward function"""
            H, W = size
            B = x.shape[0]
            for layer in self.MHCA_layers:
                x = layer(x, (H, W))
    
            # return x's shape : [B, N, C] -> [B, C, H, W]
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            return x
    
    
    class ResBlock(nn.Module):
        """Residual block for convolutional local feature."""
        def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Hardswish,
            norm_layer=nn.BatchNorm2d,
        ):
            super().__init__()
    
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.conv1 = Conv2d_BN(in_features,
                                   hidden_features,
                                   act_layer=act_layer)
            self.dwconv = nn.Conv2d(
                hidden_features,
                hidden_features,
                3,
                1,
                1,
                bias=False,
                groups=hidden_features,
            )
            self.norm = norm_layer(hidden_features)
            self.act = act_layer()
            self.conv2 = Conv2d_BN(hidden_features, out_features)
            self.apply(self._init_weights)
    
        def _init_weights(self, m):
            """
            initialization
            """
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                fan_out //= m.groups
                m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    
        def forward(self, x):
            """foward function"""
            identity = x
            feat = self.conv1(x)
            feat = self.dwconv(feat)
            feat = self.norm(feat)
            feat = self.act(feat)
            feat = self.conv2(feat)
    
            return identity + feat
    
    
    class MHCA_stage(nn.Module):
        """Multi-Head Convolutional self-Attention stage comprised of `MHCAEncoder`
        layers."""
        def __init__(
            self,
            embed_dim,
            out_embed_dim,
            num_layers=1,
            num_heads=8,
            mlp_ratio=3,
            num_path=4,
            drop_path_list=[],
        ):
            super().__init__()
    
            self.mhca_blks = nn.ModuleList([
                MHCAEncoder(
                    embed_dim,
                    num_layers,
                    num_heads,
                    mlp_ratio,
                    drop_path_list=drop_path_list,
                ) for _ in range(num_path)
            ])
    
            self.InvRes = ResBlock(in_features=embed_dim, out_features=embed_dim)
            self.aggregate = Conv2d_BN(embed_dim * (num_path + 1),
                                       out_embed_dim,
                                       act_layer=nn.Hardswish)
    
        def forward(self, inputs):
            """foward function"""
            att_outputs = [self.InvRes(inputs[0])]
            for x, encoder in zip(inputs, self.mhca_blks):
                # [B, C, H, W] -> [B, N, C]
                _, _, H, W = x.shape
                x = x.flatten(2).transpose(1, 2)
                att_outputs.append(encoder(x, size=(H, W)))
    
            out_concat = torch.cat(att_outputs, dim=1)
            out = self.aggregate(out_concat)
    
            return out
    
    
    class Cls_head(nn.Module):
        """a linear layer for classification."""
        def __init__(self, embed_dim, num_classes):
            """initialization"""
            super().__init__()
    
            self.cls = nn.Linear(embed_dim, num_classes)
    
        def forward(self, x):
            """foward function"""
            # Global average pool: (B, C, H, W) -> (B, C)
    
            x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
            # Shape : [B, C]
            out = self.cls(x)
            return out
    
    
    def dpr_generator(drop_path_rate, num_layers, num_stages):
        """Generate drop path rate list following linear decay rule."""
        dpr_list = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))
        ]
        dpr = []
        cur = 0
        for i in range(num_stages):
            dpr_per_stage = dpr_list[cur:cur + num_layers[i]]
            dpr.append(dpr_per_stage)
            cur += num_layers[i]
    
        return dpr
    
    
    class MPViT(nn.Module):
        """Multi-Path ViT class."""
        def __init__(
            self,
            img_size=224,
            num_stages=4,
            num_path=[4, 4, 4, 4],
            num_layers=[1, 1, 1, 1],
            embed_dims=[64, 128, 256, 512],
            mlp_ratios=[8, 8, 4, 4],
            num_heads=[8, 8, 8, 8],
            drop_path_rate=0.0,
            in_chans=3,
            num_classes=1000,
            **kwargs,
        ):
            super().__init__()
    
            self.num_classes = num_classes
            self.num_stages = num_stages
    
            dpr = dpr_generator(drop_path_rate, num_layers, num_stages)
    
            self.stem = nn.Sequential(
                Conv2d_BN(
                    in_chans,
                    embed_dims[0] // 2,
                    kernel_size=3,
                    stride=2,
                    pad=1,
                    act_layer=nn.Hardswish,
                ),
                Conv2d_BN(
                    embed_dims[0] // 2,
                    embed_dims[0],
                    kernel_size=3,
                    stride=2,
                    pad=1,
                    act_layer=nn.Hardswish,
                ),
            )
    
            # Patch embeddings.
            self.patch_embed_stages = nn.ModuleList([
                Patch_Embed_stage(
                    embed_dims[idx],
                    num_path=num_path[idx],
                    isPool=False if idx == 0 else True,
                ) for idx in range(self.num_stages)
            ])
    
            # Multi-Head Convolutional Self-Attention (MHCA)
            self.mhca_stages = nn.ModuleList([
                MHCA_stage(
                    embed_dims[idx],
                    embed_dims[idx + 1]
                    if not (idx + 1) == self.num_stages else embed_dims[idx],
                    num_layers[idx],
                    num_heads[idx],
                    mlp_ratios[idx],
                    num_path[idx],
                    drop_path_list=dpr[idx],
                ) for idx in range(self.num_stages)
            ])
    
            # Classification head.
            self.cls_head = Cls_head(embed_dims[-1], num_classes)
    
            self.apply(self._init_weights)
    
        def _init_weights(self, m):
            """initialization"""
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
        def get_classifier(self):
            """get classifier function"""
            return self.cls_head
    
        def forward_features(self, x):
            """forward feature function"""
    
            # x's shape : [B, C, H, W]
    
            x = self.stem(x)  # Shape : [B, C, H/4, W/4]
    
            for idx in range(self.num_stages):
                att_inputs = self.patch_embed_stages[idx](x)
                x = self.mhca_stages[idx](att_inputs)
    
            return x
    
        def forward(self, x):
            """foward function"""
            x = self.forward_features(x)
    
            # cls head
            out = self.cls_head(x)
            return out
    
    
    @register_model
    def mpvit_tiny(**kwargs):
        """mpvit_tiny :
    
        - #paths : [2, 3, 3, 3]
        - #layers : [1, 2, 4, 1]
        - #channels : [64, 96, 176, 216]
        - MLP_ratio : 2
        Number of params: 5843736
        FLOPs : 1654163812
        Activations : 16641952
        """
    
        model = MPViT(
            img_size=224,
            num_stages=4,
            num_path=[2, 3, 3, 3],
            num_layers=[1, 2, 4, 1],
            embed_dims=[64, 96, 176, 216],
            mlp_ratios=[2, 2, 2, 2],
            num_heads=[8, 8, 8, 8],
            **kwargs,
        )
        model.default_cfg = _cfg_mpvit()
        return model
    
    
    @register_model
    def mpvit_xsmall(**kwargs):
        """mpvit_xsmall :
    
        - #paths : [2, 3, 3, 3]
        - #layers : [1, 2, 4, 1]
        - #channels : [64, 128, 192, 256]
        - MLP_ratio : 4
        Number of params : 10573448
        FLOPs : 2971396560
        Activations : 21983464
        """
    
        model = MPViT(
            img_size=224,
            num_stages=4,
            num_path=[2, 3, 3, 3],
            num_layers=[1, 2, 4, 1],
            embed_dims=[64, 128, 192, 256],
            mlp_ratios=[4, 4, 4, 4],
            num_heads=[8, 8, 8, 8],
            **kwargs,
        )
        model.default_cfg = _cfg_mpvit()
        return model
    
    
    @register_model
    def mpvit_small(**kwargs):
        """mpvit_small :
    
        - #paths : [2, 3, 3, 3]
        - #layers : [1, 3, 6, 3]
        - #channels : [64, 128, 216, 288]
        - MLP_ratio : 4
        Number of params : 22892400
        FLOPs : 4799650824
        Activations : 30601880
        """
    
        model = MPViT(
            img_size=224,
            num_stages=4,
            num_path=[2, 3, 3, 3],
            num_layers=[1, 3, 6, 3],
            embed_dims=[64, 128, 216, 288],
            mlp_ratios=[4, 4, 4, 4],
            num_heads=[8, 8, 8, 8],
            **kwargs,
        )
        model.default_cfg = _cfg_mpvit()
        return model
    
    
    @register_model
    def mpvit_base(**kwargs):
        """mpvit_base :
    
        - #paths : [2, 3, 3, 3]
        - #layers : [1, 3, 8, 3]
        - #channels : [128, 224, 368, 480]
        MLP_ratio : 4
        Number of params: 74845976
        FLOPs : 16445326240
        Activations : 60204392
        """
    
        model = MPViT(
            img_size=224,
            num_stages=4,
            num_path=[2, 3, 3, 3],
            num_layers=[1, 3, 8, 3],
            embed_dims=[128, 224, 368, 480],
            mlp_ratios=[4, 4, 4, 4],
            num_heads=[8, 8, 8, 8],
            **kwargs,
        )
        model.default_cfg = _cfg_mpvit()
        return model
    
    
    if __name__ == "__main__":
        model = mpvit_xsmall()
    
        from thop import profile

        dummy_input = torch.randn(1, 3, 224, 224)
        flops, params = profile(model, inputs=(dummy_input,))
        print("flops: {:.3f}G".format(flops / 1e9))
        print("params: {:.3f}M".format(params / 1e6))
    
    

    References

    [CVPR 2022] MPViT: Multi-Path Vision Transformer for Dense Prediction – Zhihu (zhihu.com)

    Paper reading: MPViT: Multi-Path Vision Transformer for Dense Prediction – CSDN blog (甜橙不加冰)

    [Deep Learning] Semantic segmentation paper reading: (CVPR 2022) MPViT (CNN+Transformer): Multi-Path Vision Transformer for Dense Prediction – CSDN blog (sky_柘)

  • Original post: https://blog.csdn.net/wujing1_1/article/details/126085774