• Pytorch之MobileViT图像分类



    • 💂 个人主页:风间琉璃
    • 🤟 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主
    • 💬 如果文章对你有帮助欢迎关注点赞收藏(一键三连)订阅专栏

    前言

    MobileViT是一种基于ViT(Vision Transformer)架构的轻量级视觉模型,旨在适用于移动设备和嵌入式系统。ViT是一种非常成功的深度学习模型,用于图像分类和其他计算机视觉任务,但通常需要大量的计算资源和参数。MobileViT的目标是在保持高性能的同时,减少模型的大小和计算需求,以便在移动设备上运行,据作者介绍,这是第一次基于轻量级CNN网络性能的轻量级ViT工作,性能SOTA。性能优于MobileNetV3、CrossviT等网络。


    一、Transformer存在的问题

    MobileVitV1是苹果公司2021年发表的一篇轻量型主干网络,它是CNNTransfomrer混合架构模型(CNN的轻量和高效+Transformer的自注意力机制和全局视野),这样的架构模型也是现在很多研究者们青睐的架构之一。

    Vision Transformer出现之后,人们发现Transfomrer也可以应用在计算机视觉领域,并且效果还是非常不错的。但是基于Transformer的网络模型存在着以下问题:

    参数多,算力要求高
    Transformer模型通常具有数十亿或数百亿个参数,这使得它们的模型文件非常大,不仅占用大量存储空间,而且在训练和部署过程中也需要更多的计算资源。

    缺少空间归纳偏置
    即纯Transformer对空间位置信息不敏感,但是,我们在进行视觉应用的时位置信息又比较重要,为了解决这个问题就引入了位置编码。

    归纳 (Induction) 是自然科学中常用的两大方法之一 (归纳与演绎,Induction & Deduction),指从一些例子中寻找共性、泛化,形成一个较通用的规则的过程。偏置 (Bias) 则是指对模型的偏好,以下展示了 4 种解释:

    ∙ \bullet 通俗理解:归纳偏置可以理解为,从现实生活中观察到的现象中归纳出一定的 规则 (heuristics),然后对模型做一定的约束,从而可以起到 “模型选择” 的作用,类似贝叶斯学习中的 “先验”。
    ∙ \bullet 西瓜书解释:机器学习算法在学习过程中对某种类型假设的偏好,称为归纳偏好。归纳偏好可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或 “价值观”。
    ∙ \bullet 维基百科解释:如果学习器需要去预测 “其未遇到过的输入” 的结果时,则需要一些假设来帮助它做出选择。
    ∙ \bullet 广义解释:归纳偏置会促使学习算法优先考虑具有某些属性的解。

    深度神经网络偏好性地认为,层次化处理信息有更好效果;卷积神经网络认为信息具有空间局部性,可用滑动卷积共享权重的方式降低参数空间;循环神经网络则将时序信息纳入考虑,强调顺序重要性;图网络则认为中心节点与邻居节点的相似性会更好地引导信息流动。通常,模型容量 (capacity) 很大但 Inductive Bias 匮乏则容易过拟合 (overfitting),如 Transformer

    CNN的空间归纳偏差内容如下:

    CNN 的 归纳偏置(Inductive Bias)局部性 (Locality) 空间不变性 (Spatial Invariance) / 平移等效性 (Translation Equivariance),即空间位置上的元素 (Grid Elements) 的联系/相关性近大远小,以及空间平移的不变性 (Kernel 权重共享)。

    ⋆ \star locality:CNN是以滑动窗口的形式一点一点地在图片上进行卷积的,所以假设图片上相邻的区域会有相邻的特征,靠得越近的东西相关性越强;

    ⋆ \star translation equivariance(平移等变性或平移同变性):用公式表示为f(g(x))=g(f(x)),不论是先经过g映射,还是先经过f映射,其结果是不变的;其中f代表卷积操作,g代表平移操作。因为在卷积神经网络中,卷积核相当于是一个模板,不论图片中同样的物体移动到哪里,只要是相同的输入,经过相同的卷积核,其输出是不变的。

    一旦网络(CNN)模型有了这两个归纳偏置,它就拥有很多的先验信息所以只需要相对较少的数据就可以学习一个相对比较好的模型。但是对于transformer来说,它没有这些先验信息,所以它对视觉的感知全部需要从这些数据中自己学习。

    因此transformer结构的网络模型需要大量的数据才能得到不错的效果,如果使用少量数据进行训练,那么会掉点很明显。这是因为Transformer缺少空间归纳偏置,空间归纳偏置允许CNN在不同的视觉任务中学习较少参数的表示

    虽然Transformer缺少空间归纳偏置必须要大量数据来进行学习数据中的某种特性,从而导致无法很好的应用在这样的边缘设备。但是CNN也有缺点CNN在空间上获取的信息是局部的,因此一定程度上会制约着CNN网络结构的性能,而Transformer的自注意力机制能够获取全局信息。

    模型迁移困难

    这个问题核心是引入的位置编码导致的。 Transformer 网络需要先对原始的图像进行切片处理,一般来说训练好的 ViT 网络原始的输入图像大小 224×224,patch 大小为 16×16,那么得到的 patch个数也就固定了。由于 Transformer 网络缺少空间归纳偏置在计算某一个 token 时其他 token 位置顺序发生变化并不会影响到最终的实验结果,也即输出与位置信息无关。而我们知道对于图像来说,空间信息是非常重要且具有实际意义的,因此,Transformer 通过加上位置偏置ViT使用绝对位置偏置,Swin T引入相对位置偏置来解决位置信息的丢失问题

    但是,当输入图像的尺寸或者 patch 大小发生变化时,训练好的模型就会因为位置信息不准确而失效。目前常见的处理方法是将位置偏置信息进行插值,插值到所需要的序列长度从而匹配到图像的尺寸。这种方式需要对训练好的模型进行微调才能保证性能不出现大幅损失,每次改变输入图像的尺寸或者 patch 的尺寸均需要对位置编码进行插值和对网络进行微调,这提高了网络迁移的难度

    Swin T网络使用了相对位置偏置,理论上来说序列的长度只与窗 windows 的大小有关而与输入图像的尺寸无关。但是,windows的大小一般被设定与输入尺寸匹配,当输入尺寸变大时,window 的大小也应该相应的增大,那么所使用的相对位置偏置序列也应该增大,这也会导致上述问题。这些问题将导致 Transformer 网络迁移时比 CNN 网络迁移得更加困难和繁琐。

    模型训练困难

    根据现有的一些经验,Transformer相比CNN要更难训练。Transformer需要更多的训练数据需要迭代更多的epoch需要更大的正则项(L2正则)需要更多的数据增强(且对数据增强很敏感)。

    针对以上问题,采用CNN与Transformer的混合架构CNN能够提供空间归纳偏置所以可以解决位置偏置,而且加入CNN后能够加速网络的收敛,使网络训练过程更加的稳定

    二、MobileViT

    1.MobileViT网络结构

    🍓 Vision Transformer结构

    下图是MobileViT论文中绘制的Standard visual Transformer。首先将输入的图片划分成N个Patch,然后通过线性变化将每个Patch映射到一维向量中(Token),接着加上位置偏置信息(可学习参数),再通过一系列Transformer Block,最后通过一个全连接层得到最终预测输出。
    在这里插入图片描述
    首先将C,H,W的图片进行Patch处理成N个向量,然后经过线性层进行降低向量维度,再经过位置编码,然后再经过N个Transformer块,在通过class token来进行分类。

    这个Standard visual Transformer和前面文章中ViT有一点不同,这里没有class token,class token只是针对分类才加上去的,上面这个网络才是最标准的视觉ViT网络。

    由于VIT忽略了空间归纳偏差,所以它们需要更多的参数来学习视觉表征。此外,与CNN相比,VIT及其多种变体的优化性能不佳,这些模型对L2正则化很敏感,需要大量的数据增强以防止过拟合

    🍉MobileViT结构

    上面展示是标准视觉ViT模型,下面来看下本次介绍的重点:Mobile-ViT网路结构,如下图所示:
    在这里插入图片描述通过上图可以看到MobileViT主要由普通卷积MV2(MobiletNetV2中的Inverted Residual block),MobileViT block全局池化以及全连接层共同组成。

    其中,MobileViT块中的Convn × n表示一个标准的n × n卷积MV2指的是MobileNetv2块执行下采样的块用↓2标记

    2.MV2(MobileNet v2 block)

    MV2 块指MobileNet v2 block,是一个Inverted Residual Block(倒残差结构)。 在倒残差结构中,即特征图的维度是先升后降,据相关论文中描述说,更高的维度经过激活函数后,它损失的信息就会少一些。(注意倒残差结构中基本使用的都是ReLU6激活函数,但是最后一个1x1的卷积层使用的是线性激活函数)。具体网络结构如下图所示。
    在这里插入图片描述
    MobileViT结构图中标有向下箭头的MV2结构代表stride等于2的情况,即需要进行下采样

    🌼Residual Block(残差结构):
    ①1x1卷积降维
    ②3x3卷积
    ③1x1卷积升维
    🌻Inverted Residual Block(倒残差结构)
    ①1x1卷积升维
    ②3x3卷积DW
    ③1x1卷积降维

    3.MobileViT block

    MV2来源于mobilenetv2,所以Mobile-ViT的核心是MobileViT block模块。MobileViT block的结构如下图所示:
    在这里插入图片描述
    MobileViT Block旨在用更少的参数对输入张量中的局部全局信息进行建模。由上图可知MobileViT Block 整体由三部分组成分别为:Local representationsTransformers as Convolutions (global representations)Fusion

    大致流程:首先将特征图通过一个卷积核大小为nxn(代码中是3x3)的卷积层进行局部的特征建模,然后通过一个卷积核大小为1x1的卷积层调整通道数。接着通过Unfold -> Transformer -> Fold结构进行全局的特征建模,然后再通过一个卷积核大小为1x1的卷积层将通道数调整回原始大小。接着通过shortcut分支(在V2版本中将该捷径分支取消了)与原始输入特征图进行Concat拼接(沿通道channel方向拼接),最后再通过一个卷积核大小为nxn(代码中是3x3)的卷积层做特征融合得到输出

    Global representations它的具体计算过程如下图所示,
    在这里插入图片描述
    首先对特征图划分Patch(忽略了通道channels),图中的Patch大小为2x2,即每个Patch由4个Pixel组成。

    在进行Self-Attention计算的时候,每个Token(图中的每个Pixel或者说每个小颜色块)只和颜色相同的Token进行Attention,可以减少参数计算量。对于原始的Self-Attention计算每个Token是需要和所有的Token进行Self-Attention。

    假设特征图的高宽和通道数分别为H, W, C,在输入到Transformer中,在Self-Attention的时候,每个图中的每个像素和其他的像素进行计算,这样计算量就是:
    P 1 = W ∗ H ∗ C P_1 = W*H*C P1=WHC

    MobileViT中的是先对输入的特征图划分成多个的patch,但是在计算Self-Attention的时候只对相同位置的像素计算,即图中展示的颜色相同的位置,这样就可以相对的减少计算量,这个时候的计算量为:
    P 2 = W ∗ H ∗ C 4 P_2 = \frac{W*H*C}{4} P2=4WHC即理论上的计算成本仅为原始的 1 4 \frac{1}{4} 41

    在本次的自注意力机制中,只选择了位置相同的像素点进行点积操作。这样做的原因大概就是因为和所有的像素点都进行自注意力操作会带来信息冗余,毕竟不是所有的像素含有有用的信息对于图像数据本身就存在大量的数据冗余,一张图像的每个像素点的周围的像素值都差不多,并且分辨率越高相差越小,所以这样做并不会损失太多的信息。而且MobileViT在做全局表征之前已经做了一次局部表征(Local representations),进行全局建模时可以忽略一些信息。

    Global representations中的​UnfoldFold只是为了将数据给reshape成计算Self-Attention时所需的数据格式。unfold就是将颜色相同的部分拼成一个序列输入到Transformer进行建模,最后再通过fold是调整为原始大小,如下图所示:
    在这里插入图片描述
    下面来简单的看下patch size对模型性能的影响,patch如果划分的比较大的话是可以减少计算量的,但是划分的太大的话又会忽略更多的语义信息,影响模型的性能。

    下图从左到右对语义信息的要求逐渐递增。其中配置A的patch大小为{2, 2, 2},配置B的patch大小为{8, 4, 2},这三个数字分别对应下采样倍率为8,16,32的特征图所采用的patch大小。通过对比可以发现,在图像分类目标检测任务中(对语义细节要求不高的场景),配置A和配置B在Acc和mAP上没太大区别,但配置B要更快。但在语义分割任务中(对语义细节要求较高的场景)配置A的效果要更好。
    在这里插入图片描述

    🥇Local representations

    Local representations 表示输入信息的局部表达。在这个部分,输入MobileViT Block 的数据会经过一个 n × n n \times n n×n的卷积块和一个 1 × 1 1 \times 1 1×1的卷积块。

    从上文所述的CNN的空间归纳偏差就可以得知:经过 n × n n \times n n×n(n=3)的卷积块的输出获取到了输入模型的局部信息表达(因为卷积块是对一个整体块进行操作,但是这个卷积核的n是远远小于数据规模的,所以是局部信息表达,而不是全局信息表达)。另外, 1 × 1 1 \times 1 1×1的卷积块是为了线性投影将数据投影至高维空间。例如:对于 9 × 9 9\times 9 9×9的数据,使用 3 × 3 3\times 3 3×3的卷积层,获取到的每个数据都是对 9 × 9 9\times 9 9×9 数据的局部表达

    🥈Transformers as Convolutions (global representations)

    Transformers as Convolutions (global representations) 表示输入信息的全局表示。在Transformers as Convolutions 中首先通过Unfold 对数据进行转换,转化为 Transformer 可以接受的 1D 数据。然后将数据输入到Transformer 块中。最后通过Fold再将数据变换成原有的样子。

    🥉Fusion

    Fusion中,经过Transformers as Convolutions得到的信息原始输入信息 ( A ∈ R H × W × C ) (\mathrm{A} \in \mathrm{R^{\mathrm{H \times W \times C}}}) (ARH×W×C)进行合并,然后使用另一个 n × n n\times n n×n卷积层来融合这些连接的特征。这里,得到的信息指:全局表征 X F ∈ R H × W × d \mathrm{X_F} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFRH×W×d经过逐点卷积( 1 × 1 1\times 1 1×1卷积)得到的输出 X F u ∈ R H × W × d \mathrm{X_{Fu}} \in \mathrm{R^{\mathrm{H \times W \times d}}} XFuRH×W×d ,并通过Concat操作与 X \mathrm{X} X组合。

    4.模型配置

    论文中总共给出了三组模型配置,即MobileViT-S(small)、MobileViT-XS(extra small)、MobileViT-XXS(extra extra small),三种配置是越来越轻量化,三者的主要区别在于特征图的通道数不同

    下图为MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
    在这里插入图片描述
    对于MobileViT-XXS,Layer1~5的详细配置信息如下:
    在这里插入图片描述
    对于MobileViT-XS,Layer1~5的详细配置信息如下:
    在这里插入图片描述
    对于MobileViT-S,Layer1~5的详细配置信息如下:
    在这里插入图片描述
    参数说明:
    ⋆ \star out_channels表示该模块输出的通道数
    ⋆ \star mv2_exp表示Inverted Residual Block中的expansion ratio
    ⋆ \star transformer_channels表示Transformer模块输入Token的序列长度(特征图通道数)
    ⋆ \star num_heads表示多头自注意力机制中的head数
    ⋆ \star ffn_dim表示FFN中间层Token的序列长度
    ⋆ \star patch_h表示每个patch的高度
    ⋆ \star patch_w表示每个patch的宽度

    5.MobileViT优势

    🍄更好的性能: 对于给定的参数预算,MobileViT 在不同的移动视觉任务(图像分类、物体检测、语义分割)中取得了比现有的轻量级 CNN 更好的性能

    🍄更好的泛化能力泛化能力是指训练和评价指标之间的差距。对于具有相似训练指标的2个模型,具有更好评价指标的模型更具有通用性,因为它可以更好地预测未知数据集。与CNN相比,即使有广泛的数据增强,其泛化能力也很差,MobileViT显示出更好的泛化能力。

    🍄更好的鲁棒性:一个好的模型应该对超参数具有鲁棒性,因为调优这些超参数会消耗时间和资源。与大多数基于ViT的模型不同,MobileViT模型使用基本增强训练,对L2正则化不太敏感

    总之,MobileViT使用CNNTransformer相融合的方案,在减少模型复杂度的同时,提高了模型的精度和鲁棒性

    ⋆ \star 对于一个模型,如果全都使用 CNN 结构。模型只能获取到数据的局部信息而获取不到全局信息
    ⋆ \star 对于一个模型,如果全部使用 Transformer 结构。模型可以获取到全局信息。但是,Transformer 结构会带来较大的复杂度,存在训练时间上升,模型容易过拟合等等问题。

    因此,基于上述问题。作者先使用CNN获取局部信息,然后使用 Transformer 结构获取全局信息。通过上述的理解可以发现:在MobileViT 中的Transformer 结构中,复杂度相比于 ViT 结构 中复杂度降低了很多,因为输入数据复杂度的降低。最终实验结果同时表明:MobileViT 精度更高且鲁棒性更好

    三、MobileViT网络实现

    1.构建网络模型

    首先要构建MobileViT block,其结构图如下所示:
    在这里插入图片描述
    Transformer实现:
    在这里插入图片描述

    
    class MultiHeadAttention(nn.Module):
        """
        This layer applies a multi-head self- or cross-attention as described in
        `Attention is all you need `_ paper
    
        Args:
            embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
            num_heads (int): Number of heads in multi-head attention
            attn_dropout (float): Attention dropout. Default: 0.0
            bias (bool): Use bias or not. Default: ``True``
    
        Shape:
            - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
            and :math:`C_{in}` is input embedding dim
            - Output: same shape as the input
    
        """
    
        def __init__(
            self,
            embed_dim: int,
            num_heads: int,
            attn_dropout: float = 0.0,
            bias: bool = True,
            *args,
            **kwargs
        ) -> None:
            super().__init__()
            if embed_dim % num_heads != 0:
                raise ValueError(
                    "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                        self.__class__.__name__, embed_dim, num_heads
                    )
                )
    
            self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
    
            self.attn_dropout = nn.Dropout(p=attn_dropout)
            self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)
    
            self.head_dim = embed_dim // num_heads
            self.scaling = self.head_dim ** -0.5
            self.softmax = nn.Softmax(dim=-1)
            self.num_heads = num_heads
            self.embed_dim = embed_dim
    
        def forward(self, x_q: Tensor) -> Tensor:
            # [N, P, C]
            b_sz, n_patches, in_channels = x_q.shape
    
            # self-attention
            # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
            qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)
    
            # [N, P, 3, h, c] -> [N, h, 3, P, C]
            qkv = qkv.transpose(1, 3).contiguous()
    
            # [N, h, 3, P, C] -> [N, h, P, C] x 3
            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
    
            query = query * self.scaling
    
            # [N h, P, c] -> [N, h, c, P]
            key = key.transpose(-1, -2)
    
            # QK^T
            # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
            attn = torch.matmul(query, key)
            attn = self.softmax(attn)
            attn = self.attn_dropout(attn)
    
            # weighted sum
            # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
            out = torch.matmul(attn, value)
    
            # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
            out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
            out = self.out_proj(out)
    
            return out
    
    
    class TransformerEncoder(nn.Module):
        """
        This class defines the pre-norm `Transformer encoder `_
        Args:
            embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
            ffn_latent_dim (int): Inner dimension of the FFN
            num_heads (int) : Number of heads in multi-head attention. Default: 8
            attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
            dropout (float): Dropout rate. Default: 0.0
            ffn_dropout (float): Dropout between FFN layers. Default: 0.0
    
        Shape:
            - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
            and :math:`C_{in}` is input embedding dim
            - Output: same shape as the input
        """
    
        def __init__(
            self,
            embed_dim: int,
            ffn_latent_dim: int,
            num_heads: Optional[int] = 8,
            attn_dropout: Optional[float] = 0.0,
            dropout: Optional[float] = 0.0,
            ffn_dropout: Optional[float] = 0.0,
            *args,
            **kwargs
        ) -> None:
    
            super().__init__()
    
            attn_unit = MultiHeadAttention(
                embed_dim,
                num_heads,
                attn_dropout=attn_dropout,
                bias=True
            )
    
            self.pre_norm_mha = nn.Sequential(
                nn.LayerNorm(embed_dim),
                attn_unit,
                nn.Dropout(p=dropout)
            )
    
            self.pre_norm_ffn = nn.Sequential(
                nn.LayerNorm(embed_dim),
                nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
                nn.SiLU(),
                nn.Dropout(p=ffn_dropout),
                nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
                nn.Dropout(p=dropout)
            )
            self.embed_dim = embed_dim
            self.ffn_dim = ffn_latent_dim
            self.ffn_dropout = ffn_dropout
            self.std_dropout = dropout
    
        def forward(self, x: Tensor) -> Tensor:
            # multi-head attention
            res = x
            x = self.pre_norm_mha(x)
            x = x + res
    
            # feed forward network
            x = x + self.pre_norm_ffn(x)
            return x
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150

    MobileViT的整体框架,主要看下图中的标出的Layer1~5,这里是根据源码中的配置信息划分的:
    在这里插入图片描述

    def get_config(mode: str = "xxs") -> dict:
        if mode == "xx_small":
            mv2_exp_mult = 2
            config = {
                "layer1": {
                    "out_channels": 16,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 1,
                    "stride": 1,
                    "block_type": "mv2",
                },
                "layer2": {
                    "out_channels": 24,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 3,
                    "stride": 2,
                    "block_type": "mv2",
                },
                "layer3": {  # 28x28
                    "out_channels": 48,
                    "transformer_channels": 64,
                    "ffn_dim": 128,
                    "transformer_blocks": 2,
                    "patch_h": 2,  # 8,
                    "patch_w": 2,  # 8,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer4": {  # 14x14
                    "out_channels": 64,
                    "transformer_channels": 80,
                    "ffn_dim": 160,
                    "transformer_blocks": 4,
                    "patch_h": 2,  # 4,
                    "patch_w": 2,  # 4,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer5": {  # 7x7
                    "out_channels": 80,
                    "transformer_channels": 96,
                    "ffn_dim": 192,
                    "transformer_blocks": 3,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "last_layer_exp_factor": 4,
                "cls_dropout": 0.1
            }
        elif mode == "x_small":
            mv2_exp_mult = 4
            config = {
                "layer1": {
                    "out_channels": 32,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 1,
                    "stride": 1,
                    "block_type": "mv2",
                },
                "layer2": {
                    "out_channels": 48,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 3,
                    "stride": 2,
                    "block_type": "mv2",
                },
                "layer3": {  # 28x28
                    "out_channels": 64,
                    "transformer_channels": 96,
                    "ffn_dim": 192,
                    "transformer_blocks": 2,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer4": {  # 14x14
                    "out_channels": 80,
                    "transformer_channels": 120,
                    "ffn_dim": 240,
                    "transformer_blocks": 4,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer5": {  # 7x7
                    "out_channels": 96,
                    "transformer_channels": 144,
                    "ffn_dim": 288,
                    "transformer_blocks": 3,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "last_layer_exp_factor": 4,
                "cls_dropout": 0.1
            }
        elif mode == "small":
            mv2_exp_mult = 4
            config = {
                "layer1": {
                    "out_channels": 32,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 1,
                    "stride": 1,
                    "block_type": "mv2",
                },
                "layer2": {
                    "out_channels": 64,
                    "expand_ratio": mv2_exp_mult,
                    "num_blocks": 3,
                    "stride": 2,
                    "block_type": "mv2",
                },
                "layer3": {  # 28x28
                    "out_channels": 96,
                    "transformer_channels": 144,
                    "ffn_dim": 288,
                    "transformer_blocks": 2,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer4": {  # 14x14
                    "out_channels": 128,
                    "transformer_channels": 192,
                    "ffn_dim": 384,
                    "transformer_blocks": 4,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "layer5": {  # 7x7
                    "out_channels": 160,
                    "transformer_channels": 240,
                    "ffn_dim": 480,
                    "transformer_blocks": 3,
                    "patch_h": 2,
                    "patch_w": 2,
                    "stride": 2,
                    "mv_expand_ratio": mv2_exp_mult,
                    "num_heads": 4,
                    "block_type": "mobilevit",
                },
                "last_layer_exp_factor": 4,
                "cls_dropout": 0.1
            }
        else:
            raise NotImplementedError
    
        for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
            config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})
    
        return config
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177

    现在开始构建MobileVit网络模型

    def make_divisible(
        v: Union[float, int],
        divisor: Optional[int] = 8,
        min_value: Optional[Union[float, int]] = None,
    ) -> Union[float, int]:
        """
        This function is taken from the original tf repo.
        It ensures that all layers have a channel number that is divisible by 8
        It can be seen here:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
        :param v:
        :param divisor:
        :param min_value:
        :return:
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v
    
    # 卷积层
    class ConvLayer(nn.Module):
        """
        Applies a 2D convolution over an input
    
        Args:
            in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
            out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
            kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
            stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
            groups (Optional[int]): Number of groups in convolution. Default: 1
            bias (Optional[bool]): Use bias. Default: ``False``
            use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
            use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
                                    Default: ``True``
    
        Shape:
            - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
            - Output: :math:`(N, C_{out}, H_{out}, W_{out})`
    
        .. note::
            For depth-wise convolution, `groups=C_{in}=C_{out}`.
        """
    
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Optional[Union[int, Tuple[int, int]]] = 1,
            groups: Optional[int] = 1,
            bias: Optional[bool] = False,
            use_norm: Optional[bool] = True,
            use_act: Optional[bool] = True,
        ) -> None:
            super().__init__()
    
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size)
    
            if isinstance(stride, int):
                stride = (stride, stride)
    
            assert isinstance(kernel_size, Tuple)
            assert isinstance(stride, Tuple)
    
            padding = (
                int((kernel_size[0] - 1) / 2),
                int((kernel_size[1] - 1) / 2),
            )
    
            block = nn.Sequential()
    
            conv_layer = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                groups=groups,
                padding=padding,
                bias=bias
            )
    
            block.add_module(name="conv", module=conv_layer)
    
            if use_norm:
                norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
                block.add_module(name="norm", module=norm_layer)
    
            if use_act:
                act_layer = nn.SiLU()
                block.add_module(name="act", module=act_layer)
    
            self.block = block
    
        def forward(self, x: Tensor) -> Tensor:
            return self.block(x)
    
    # MV2
    class InvertedResidual(nn.Module):
        """
        This class implements the inverted residual block, as described in `MobileNetv2 `_ paper
    
        Args:
            in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
            out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out)`
            stride (int): Use convolutions with a stride. Default: 1
            expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
            skip_connection (Optional[bool]): Use skip-connection. Default: True
    
        Shape:
            - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
            - Output: :math:`(N, C_{out}, H_{out}, W_{out})`
    
        .. note::
            If `in_channels =! out_channels` and `stride > 1`, we set `skip_connection=False`
    
        """
    
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            stride: int,
            expand_ratio: Union[int, float],
            skip_connection: Optional[bool] = True,
        ) -> None:
            assert stride in [1, 2]
            hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
    
            super().__init__()
    
            block = nn.Sequential()
            if expand_ratio != 1:
                block.add_module(
                    name="exp_1x1",
                    module=ConvLayer(
                        in_channels=in_channels,
                        out_channels=hidden_dim,
                        kernel_size=1
                    ),
                )
    
            block.add_module(
                name="conv_3x3",
                module=ConvLayer(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    stride=stride,
                    kernel_size=3,
                    groups=hidden_dim
                ),
            )
    
            block.add_module(
                name="red_1x1",
                module=ConvLayer(
                    in_channels=hidden_dim,
                    out_channels=out_channels,
                    kernel_size=1,
                    use_act=False,
                    use_norm=True,
                ),
            )
    
            self.block = block
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.exp = expand_ratio
            self.stride = stride
            self.use_res_connect = (
                self.stride == 1 and in_channels == out_channels and skip_connection
            )
    
        def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
            if self.use_res_connect:
                return x + self.block(x)
            else:
                return self.block(x)
    
    
    class MobileViTBlock(nn.Module):
        """
        This class defines the `MobileViT block `_
    
        Args:
            opts: command line arguments
            in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
            transformer_dim (int): Input dimension to the transformer unit
            ffn_dim (int): Dimension of the FFN block
            n_transformer_blocks (int): Number of transformer blocks. Default: 2
            head_dim (int): Head dimension in the multi-head attention. Default: 32
            attn_dropout (float): Dropout in multi-head attention. Default: 0.0
            dropout (float): Dropout rate. Default: 0.0
            ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
            patch_h (int): Patch height for unfolding operation. Default: 8
            patch_w (int): Patch width for unfolding operation. Default: 8
            transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
            conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
            no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
        """
    
        def __init__(
            self,
            in_channels: int,
            transformer_dim: int,
            ffn_dim: int,
            n_transformer_blocks: int = 2,
            head_dim: int = 32,
            attn_dropout: float = 0.0,
            dropout: float = 0.0,
            ffn_dropout: float = 0.0,
            patch_h: int = 8,
            patch_w: int = 8,
            conv_ksize: Optional[int] = 3,
            *args,
            **kwargs
        ) -> None:
            super().__init__()
    
            # 下面两个卷积层:Local representations
            conv_3x3_in = ConvLayer(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1
            )
            conv_1x1_in = ConvLayer(
                in_channels=in_channels,
                out_channels=transformer_dim,
                kernel_size=1,
                stride=1,
                use_norm=False,
                use_act=False
            )
    
            # 下面两个卷积层:Fusion
            conv_1x1_out = ConvLayer(
                in_channels=transformer_dim,
                out_channels=in_channels,
                kernel_size=1,
                stride=1
            )
            conv_3x3_out = ConvLayer(
                in_channels=2 * in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1
            )
    
            # Local representations
            self.local_rep = nn.Sequential()
            self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
            self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
    
            assert transformer_dim % head_dim == 0
            num_heads = transformer_dim // head_dim
    
            # global representations
            global_rep = [
                TransformerEncoder(
                    embed_dim=transformer_dim,
                    ffn_latent_dim=ffn_dim,
                    num_heads=num_heads,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    ffn_dropout=ffn_dropout
                )
                for _ in range(n_transformer_blocks)
            ]
            global_rep.append(nn.LayerNorm(transformer_dim))
            self.global_rep = nn.Sequential(*global_rep)
    
            # Fusion
            self.conv_proj = conv_1x1_out
            self.fusion = conv_3x3_out
    
            self.patch_h = patch_h
            self.patch_w = patch_w
            self.patch_area = self.patch_w * self.patch_h
    
            self.cnn_in_dim = in_channels
            self.cnn_out_dim = transformer_dim
            self.n_heads = num_heads
            self.ffn_dim = ffn_dim
            self.dropout = dropout
            self.attn_dropout = attn_dropout
            self.ffn_dropout = ffn_dropout
            self.n_blocks = n_transformer_blocks
            self.conv_ksize = conv_ksize
    
        def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
            patch_w, patch_h = self.patch_w, self.patch_h
            patch_area = patch_w * patch_h
            batch_size, in_channels, orig_h, orig_w = x.shape
    
            new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
            new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
    
            interpolate = False
            if new_w != orig_w or new_h != orig_h:
                # Note: Padding can be done, but then it needs to be handled in attention function.
                x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
                interpolate = True
    
            # number of patches along width and height
            num_patch_w = new_w // patch_w  # n_w
            num_patch_h = new_h // patch_h  # n_h
            num_patches = num_patch_h * num_patch_w  # N
    
            # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
            x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
            # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
            x = x.transpose(1, 2)
            # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
            x = x.reshape(batch_size, in_channels, num_patches, patch_area)
            # [B, C, N, P] -> [B, P, N, C]
            x = x.transpose(1, 3)
            # [B, P, N, C] -> [BP, N, C]
            x = x.reshape(batch_size * patch_area, num_patches, -1)
    
            info_dict = {
                "orig_size": (orig_h, orig_w),
                "batch_size": batch_size,
                "interpolate": interpolate,
                "total_patches": num_patches,
                "num_patches_w": num_patch_w,
                "num_patches_h": num_patch_h,
            }
    
            return x, info_dict
    
        def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
            n_dim = x.dim()
            assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
                x.shape
            )
            # [BP, N, C] --> [B, P, N, C]
            x = x.contiguous().view(
                info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
            )
    
            batch_size, pixels, num_patches, channels = x.size()
            num_patch_h = info_dict["num_patches_h"]
            num_patch_w = info_dict["num_patches_w"]
    
            # [B, P, N, C] -> [B, C, N, P]
            x = x.transpose(1, 3)
            # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
            x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
            # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
            x = x.transpose(1, 2)
            # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
            x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
            if info_dict["interpolate"]:
                x = F.interpolate(
                    x,
                    size=info_dict["orig_size"],
                    mode="bilinear",
                    align_corners=False,
                )
            return x
    
        def forward(self, x: Tensor) -> Tensor:
            res = x
    
            fm = self.local_rep(x)
    
            # convert feature map to patches
            patches, info_dict = self.unfolding(fm)
    
            # learn global representations
            for transformer_layer in self.global_rep:
                patches = transformer_layer(patches)
    
            # [B x Patch x Patches x C] -> [B x C x Patches x Patch]
            fm = self.folding(x=patches, info_dict=info_dict)
    
            fm = self.conv_proj(fm)
    
            fm = self.fusion(torch.cat((res, fm), dim=1))
            return fm
    
    
    class MobileViT(nn.Module):
        """
        This class implements the `MobileViT architecture `_
        """
        def __init__(self, model_cfg: Dict, num_classes: int = 1000):
            super().__init__()
    
            image_channels = 3
            out_channels = 16
    
            self.conv_1 = ConvLayer(
                in_channels=image_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=2
            )
    
            self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
            self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
            self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
            self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
            self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])
    
            exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
            self.conv_1x1_exp = ConvLayer(
                in_channels=out_channels,
                out_channels=exp_channels,
                kernel_size=1
            )
    
            self.classifier = nn.Sequential()
            self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
            self.classifier.add_module(name="flatten", module=nn.Flatten())
            if 0.0 < model_cfg["cls_dropout"] < 1.0:
                self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
            self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))
    
            # weight init
            self.apply(self.init_parameters)
    
        def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
            block_type = cfg.get("block_type", "mobilevit")
            if block_type.lower() == "mobilevit":
                return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
            else:
                return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)
    
        @staticmethod
        def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
            output_channels = cfg.get("out_channels")
            num_blocks = cfg.get("num_blocks", 2)
            expand_ratio = cfg.get("expand_ratio", 4)
            block = []
    
            for i in range(num_blocks):
                stride = cfg.get("stride", 1) if i == 0 else 1
    
                layer = InvertedResidual(
                    in_channels=input_channel,
                    out_channels=output_channels,
                    stride=stride,
                    expand_ratio=expand_ratio
                )
                block.append(layer)
                input_channel = output_channels
    
            return nn.Sequential(*block), input_channel
    
        @staticmethod
        def _make_mit_layer(input_channel: int, cfg: Dict) -> [nn.Sequential, int]:
            stride = cfg.get("stride", 1)
            block = []
    
            if stride == 2:
                layer = InvertedResidual(
                    in_channels=input_channel,
                    out_channels=cfg.get("out_channels"),
                    stride=stride,
                    expand_ratio=cfg.get("mv_expand_ratio", 4)
                )
    
                block.append(layer)
                input_channel = cfg.get("out_channels")
    
            transformer_dim = cfg["transformer_channels"]
            ffn_dim = cfg.get("ffn_dim")
            num_heads = cfg.get("num_heads", 4)
            head_dim = transformer_dim // num_heads
    
            if transformer_dim % head_dim != 0:
                raise ValueError("Transformer input dimension should be divisible by head dimension. "
                                 "Got {} and {}.".format(transformer_dim, head_dim))
    
            block.append(MobileViTBlock(
                in_channels=input_channel,
                transformer_dim=transformer_dim,
                ffn_dim=ffn_dim,
                n_transformer_blocks=cfg.get("transformer_blocks", 1),
                patch_h=cfg.get("patch_h", 2),
                patch_w=cfg.get("patch_w", 2),
                dropout=cfg.get("dropout", 0.1),
                ffn_dropout=cfg.get("ffn_dropout", 0.0),
                attn_dropout=cfg.get("attn_dropout", 0.1),
                head_dim=head_dim,
                conv_ksize=3
            ))
    
            return nn.Sequential(*block), input_channel
    
        @staticmethod
        def init_parameters(m):
            if isinstance(m, nn.Conv2d):
                if m.weight is not None:
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Linear,)):
                if m.weight is not None:
                    nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            else:
                pass
    
        def forward(self, x: Tensor) -> Tensor:
            x = self.conv_1(x)
            x = self.layer_1(x)
            x = self.layer_2(x)
    
            x = self.layer_3(x)
            x = self.layer_4(x)
            x = self.layer_5(x)
            x = self.conv_1x1_exp(x)
            x = self.classifier(x)
            return x
    
    
    def mobile_vit_xx_small(num_classes: int = 1000):
        # pretrain weight link
        # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
        config = get_config("xx_small")
        m = MobileViT(config, num_classes=num_classes)
        return m
    
    
    def mobile_vit_x_small(num_classes: int = 1000):
        # pretrain weight link
        # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
        config = get_config("x_small")
        m = MobileViT(config, num_classes=num_classes)
        return m
    
    
    def mobile_vit_small(num_classes: int = 1000):
        # pretrain weight link
        # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
        config = get_config("small")
        m = MobileViT(config, num_classes=num_classes)
        return m
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336
    • 337
    • 338
    • 339
    • 340
    • 341
    • 342
    • 343
    • 344
    • 345
    • 346
    • 347
    • 348
    • 349
    • 350
    • 351
    • 352
    • 353
    • 354
    • 355
    • 356
    • 357
    • 358
    • 359
    • 360
    • 361
    • 362
    • 363
    • 364
    • 365
    • 366
    • 367
    • 368
    • 369
    • 370
    • 371
    • 372
    • 373
    • 374
    • 375
    • 376
    • 377
    • 378
    • 379
    • 380
    • 381
    • 382
    • 383
    • 384
    • 385
    • 386
    • 387
    • 388
    • 389
    • 390
    • 391
    • 392
    • 393
    • 394
    • 395
    • 396
    • 397
    • 398
    • 399
    • 400
    • 401
    • 402
    • 403
    • 404
    • 405
    • 406
    • 407
    • 408
    • 409
    • 410
    • 411
    • 412
    • 413
    • 414
    • 415
    • 416
    • 417
    • 418
    • 419
    • 420
    • 421
    • 422
    • 423
    • 424
    • 425
    • 426
    • 427
    • 428
    • 429
    • 430
    • 431
    • 432
    • 433
    • 434
    • 435
    • 436
    • 437
    • 438
    • 439
    • 440
    • 441
    • 442
    • 443
    • 444
    • 445
    • 446
    • 447
    • 448
    • 449
    • 450
    • 451
    • 452
    • 453
    • 454
    • 455
    • 456
    • 457
    • 458
    • 459
    • 460
    • 461
    • 462
    • 463
    • 464
    • 465
    • 466
    • 467
    • 468
    • 469
    • 470
    • 471
    • 472
    • 473
    • 474
    • 475
    • 476
    • 477
    • 478
    • 479
    • 480
    • 481
    • 482
    • 483
    • 484
    • 485
    • 486
    • 487
    • 488
    • 489
    • 490
    • 491
    • 492
    • 493
    • 494
    • 495
    • 496
    • 497
    • 498
    • 499
    • 500
    • 501
    • 502
    • 503
    • 504
    • 505
    • 506
    • 507
    • 508
    • 509
    • 510
    • 511
    • 512
    • 513
    • 514
    • 515
    • 516
    • 517
    • 518
    • 519
    • 520
    • 521
    • 522
    • 523
    • 524
    • 525
    • 526
    • 527
    • 528
    • 529
    • 530
    • 531
    • 532
    • 533
    • 534
    • 535
    • 536
    • 537
    • 538
    • 539
    • 540
    • 541
    • 542
    • 543
    • 544
    • 545
    • 546
    • 547
    • 548
    • 549
    • 550
    • 551
    • 552

    2.训练和测试模型

    def main(args):
        device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    
        if os.path.exists("./weights") is False:
            os.makedirs("./weights")
    
        tb_writer = SummaryWriter()
    
        train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
    
        img_size = 224
        data_transform = {
            "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
            "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                       transforms.CenterCrop(img_size),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    
        # 实例化训练数据集
        train_dataset = MyDataSet(images_path=train_images_path,
                                  images_class=train_images_label,
                                  transform=data_transform["train"])
    
        # 实例化验证数据集
        val_dataset = MyDataSet(images_path=val_images_path,
                                images_class=val_images_label,
                                transform=data_transform["val"])
    
        batch_size = args.batch_size
        nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
        print('Using {} dataloader workers every process'.format(nw))
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   pin_memory=True,
                                                   num_workers=nw,
                                                   collate_fn=train_dataset.collate_fn)
    
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=nw,
                                                 collate_fn=val_dataset.collate_fn)
    
        model = create_model(num_classes=args.num_classes).to(device)
    
        if args.weights != "":
            assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
            weights_dict = torch.load(args.weights, map_location=device)
            weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
            # 删除有关分类类别的权重
            for k in list(weights_dict.keys()):
                if "classifier" in k:
                    del weights_dict[k]
            print(model.load_state_dict(weights_dict, strict=False))
    
        if args.freeze_layers:
            for name, para in model.named_parameters():
                # 除head外,其他权重全部冻结
                if "classifier" not in name:
                    para.requires_grad_(False)
                else:
                    print("training {}".format(name))
    
        pg = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)
    
        best_acc = 0.
        for epoch in range(args.epochs):
            # train
            train_loss, train_acc = train_one_epoch(model=model,
                                                    optimizer=optimizer,
                                                    data_loader=train_loader,
                                                    device=device,
                                                    epoch=epoch)
    
            # validate
            val_loss, val_acc = evaluate(model=model,
                                         data_loader=val_loader,
                                         device=device,
                                         epoch=epoch)
    
            tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
            tb_writer.add_scalar(tags[0], train_loss, epoch)
            tb_writer.add_scalar(tags[1], train_acc, epoch)
            tb_writer.add_scalar(tags[2], val_loss, epoch)
            tb_writer.add_scalar(tags[3], val_acc, epoch)
            tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
    
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), "./weights/best_model.pth")
    
            torch.save(model.state_dict(), "./weights/latest_model.pth")
    
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--num_classes', type=int, default=5)
        parser.add_argument('--epochs', type=int, default=50)
        parser.add_argument('--batch-size', type=int, default=8)
        parser.add_argument('--lr', type=float, default=0.0002)
    
        # 数据集所在根目录
        # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
        parser.add_argument('--data-path', type=str,
                            default="F:/NN/Learn_Pytorch/flower_photos")
    
        # 预训练权重路径,如果不想载入就设置为空字符
        parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
                            help='initial weights path')
        # 是否冻结权重
        parser.add_argument('--freeze-layers', type=bool, default=False)
        parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
    
        opt = parser.parse_args()
    
        main(opt)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123

    这里使用了预训练权重,在其基础上训练自己的数据集。训练50epoch的准确率能到达94%左右。
    在这里插入图片描述

    四、图像分类

    这里使用花朵数据集,下载连接:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz

    
    def main():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
        img_size = 224
        data_transform = transforms.Compose(
            [transforms.Resize(int(img_size * 1.14)),
             transforms.CenterCrop(img_size),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    
        # 加载图片
        img_path = 'daisy2.jpg'
        assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
        image = Image.open(img_path)
    
        # image.show()
        # [N, C, H, W]
        img = data_transform(image)
        # 扩展维度
        img = torch.unsqueeze(img, dim=0)
    
        # 获取标签
        json_path = 'class_indices.json'
        assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
        with open(json_path, 'r') as f:
            # 使用json.load()函数加载JSON文件的内容并将其存储在一个Python字典中
            class_indict = json.load(f)
    
        # create model
        model = create_model(num_classes=5).to(device)
        # load model weights
        model_weight_path = "./weights/best_model.pth"
        model.load_state_dict(torch.load(model_weight_path, map_location=device))
    
        model.eval()
        with torch.no_grad():
            # 对输入图像进行预测
            output = torch.squeeze(model(img.to(device))).cpu()
            # 对模型的输出进行 softmax 操作,将输出转换为类别概率
            predict = torch.softmax(output, dim=0)
            # 得到高概率的类别的索引
            predict_cla = torch.argmax(predict).numpy()
    
        res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy())
        draw = ImageDraw.Draw(image)
        # 文本的左上角位置
        position = (10, 10)
        # fill 指定文本颜色
        draw.text(position, res, fill='green')
        image.show()
        for i in range(len(predict)):
            print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], predict[i].numpy()))
    
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55

    预测结果:
    在这里插入图片描述

    结束语

    感谢阅读吾之文章,今已至此次旅程之终站 🛬。

    吾望斯文献能供尔以宝贵之信息与知识也 🎉。

    学习者之途,若藏于天际之星辰🍥,吾等皆当努力熠熠生辉,持续前行。

    然而,如若斯文献有益于尔,何不以三连为礼?点赞、留言、收藏 - 此等皆以证尔对作者之支持与鼓励也 💞。

  • 相关阅读:
    Python基础入门篇【27】--python基础入门练习卷C
    【树莓派不吃灰】基础篇⑱ 从0到1搭建docker环境,顺便安装一下emqx MQTT Broker、HomeAssistant、portainer
    css3中有哪些伪选择器?
    每天一个知识点-如何保证缓存一致性
    Linux学习笔记14 - 多线程编程(一)
    网络文化经营许可证这样办,省时又便捷!
    Matlab论文插图绘制模板第67期—三角网格图(Trimesh)
    php安装ldap扩展模块
    学习C++第二课
    Github 2024-07-11 开源项目日报 Top10
  • 原文地址:https://blog.csdn.net/qq_53144843/article/details/133621241