• YOLO Backbone Improvement Series: EfficientViT


    EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention
    Abstract: Vision transformers have achieved great success thanks to their high model capacity. However, their remarkable performance comes with heavy computational cost, which makes them unsuitable for real-time applications. This paper proposes a family of high-speed vision transformers named EfficientViT. The authors find that the speed of existing transformer models is commonly bounded by memory-inefficient operations, especially the tensor reshaping and element-wise functions in MHSA. They therefore design a new building block with a sandwich layout, which places a single memory-bound MHSA between efficient FFN layers, improving memory efficiency while enhancing channel communication. In addition, they observe that attention maps have high similarity across heads, leading to computational redundancy. To address this, a cascaded group attention module is proposed that feeds the attention heads with different splits of the full feature, which not only saves computation but also improves attention diversity. Comprehensive experiments show that EfficientViT outperforms existing efficient models, striking a good balance between speed and accuracy. For example, EfficientViT-M5 surpasses MobileNetV3-Large by 1.9% in accuracy while achieving 40.4% and 45.2% higher throughput on an Nvidia V100 GPU and an Intel Xeon CPU, respectively. Compared with the recent efficient model MobileViT-XXS, EfficientViT-M2 achieves 1.8% higher accuracy while running 5.8×/3.7× faster on GPU/CPU, and 7.4× faster when converted to ONNX format.

    By analyzing the DeiT and Swin Transformer architectures, the paper draws the following conclusions:

    • Appropriately reducing the proportion of MHSA layers improves memory-access efficiency while also improving model performance
    • Feeding each head with a different channel split of the features, rather than the same full features as in standard MHSA, effectively reduces redundancy in the attention computation
    • The typical channel configuration, i.e., doubling the channel count after each stage or using the same channels for all blocks, may produce substantial redundancy in the last few blocks
    • With the same dimensionality, Q and K carry much more redundancy than V

    Based on these observations, EfficientViT introduces two key designs:

    • A new building block with a sandwich layout (fewer self-attention operations): instead of a block that stacks self-attention->FFN->self-attention->FFN->... N times, the new block is FFN->self-attention->FFN, i.e., a single memory-bound self-attention sandwiched between FFN layers. This not only improves memory efficiency but also strengthens communication between channels (a minimal sketch is given right after this list).
    • Cascaded group attention: the heads learn features in a cascade. After the first head has learned its features, the second head learns on top of the first head's output (whereas in a standard transformer the second head learns independently of and in parallel with the first); likewise, the third head builds on the result of the second, and so on.
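
    The following is a minimal, self-contained PyTorch sketch of the sandwich-layout idea. ToyAttention and ToyFFN are simplified stand-ins, not the paper's actual blocks; the real EfficientViTBlock implementation appears in step (2) below.

    # Minimal sketch (not the official implementation) contrasting a conventional
    # block stack with the sandwich layout, using toy modules on (B, C, H, W) maps.
    import torch
    import torch.nn as nn

    class ToyAttention(nn.Module):
        """Stand-in for a memory-bound self-attention over feature maps."""
        def __init__(self, dim):
            super().__init__()
            self.qkv = nn.Conv2d(dim, dim * 3, 1)
            self.proj = nn.Conv2d(dim, dim, 1)

        def forward(self, x):
            B, C, H, W = x.shape
            q, k, v = self.qkv(x).flatten(2).chunk(3, dim=1)      # each: B, C, N
            attn = (q.transpose(-2, -1) @ k) * C ** -0.5          # B, N, N
            out = v @ attn.softmax(dim=-1).transpose(-2, -1)      # B, C, N
            return x + self.proj(out.view(B, C, H, W))

    class ToyFFN(nn.Module):
        """Stand-in for an efficient pointwise-convolution FFN."""
        def __init__(self, dim):
            super().__init__()
            self.net = nn.Sequential(nn.Conv2d(dim, dim * 2, 1), nn.ReLU(), nn.Conv2d(dim * 2, dim, 1))

        def forward(self, x):
            return x + self.net(x)

    dim = 64
    # Conventional layout: self-attention -> FFN repeated N times (here N = 3).
    conventional = nn.Sequential(*[nn.Sequential(ToyAttention(dim), ToyFFN(dim)) for _ in range(3)])
    # Sandwich layout: a single memory-bound self-attention between cheap FFN layers.
    sandwich = nn.Sequential(ToyFFN(dim), ToyAttention(dim), ToyFFN(dim))

    x = torch.randn(1, dim, 14, 14)
    print(conventional(x).shape, sandwich(x).shape)  # both torch.Size([1, 64, 14, 14])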

    The overall EfficientViT model structure is shown in the figure below:
    [Figure: EfficientViT overall architecture]

    A memory-efficient sandwich layout
    [Figure: the sandwich-layout building block]

    Cascaded Group Attention: this addresses the redundant multi-head learning of the original model (many of the learned features are similar). Here each head learns different features, and the later a head sits in the cascade, the richer the features it learns.
    [Figure: Cascaded Group Attention]

    Q performs the active querying and its features are richer than K's, so an extra Token Interaction is applied to Q:
    before self-attention, Q is first passed through a grouped (depthwise) convolution to learn once more.
    Parameter Reallocation
    Self-attention spends most of its work on the Q·K product, and Q/K also need to be reshaped, so for faster computation Q and K are given smaller dimensions,
    while V is only weighted afterwards by the result of Q·K, which is much cheaper, so V is given a larger dimension in order to learn more features.
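
    The sketch below illustrates only the cascading mechanism; the per-head modules here are stand-in 1x1 convolutions, while the complete CascadedGroupAttention class (with attention, the Q token interaction, and attention biases) is given in step (2) below.

    # Minimal sketch of cascaded heads: each head gets its own channel split plus
    # the output of the previous head, so later heads see progressively richer features.
    import torch

    def cascaded_heads(x, heads):
        # x: (B, C, H, W); heads: list of modules, each mapping C/h -> C/h channels
        splits = x.chunk(len(heads), dim=1)
        outs, feat = [], splits[0]
        for i, head in enumerate(heads):
            if i > 0:
                feat = splits[i] + feat   # cascade: add the previous head's output
            feat = head(feat)
            outs.append(feat)
        return torch.cat(outs, dim=1)

    heads = [torch.nn.Conv2d(16, 16, 1) for _ in range(4)]   # stand-ins for attention heads
    print(cascaded_heads(torch.randn(1, 64, 14, 14), heads).shape)  # torch.Size([1, 64, 14, 14])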

    The EfficientViT variant configurations are shown in the table below:
    [Table: EfficientViT variant configurations]

    Tutorial for adding EfficientViT as the backbone of a YOLOv5 project:
    (1) In the YOLOv5 project, modify the parse_model function and BaseModel's _forward_once function in models/yolo.py:

    def parse_model(d, ch):  # model_dict, input_channels(3)
        # Parse a YOLOv5 model.yaml dictionary
        LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
        anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
        if act:
            Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print
        na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
        no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)
    
        # ---------------------------------------------------------------------------------------------------
        is_backbone = False
        layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
        for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
            try:
                t = m
                m = eval(m) if isinstance(m, str) else m  # eval strings
            except:
                pass
            for j, a in enumerate(args):
                with contextlib.suppress(NameError):
                    try:
                        args[j] = eval(a) if isinstance(a, str) else a  # eval strings
                    except:
                        args[j] = a
    
            n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
            if m in {
                    Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                    BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
                c1, c2 = ch[f], args[0]
                if c2 != no:  # if not output
                    c2 = make_divisible(c2 * gw, 8)
    
                args = [c1, c2, *args[1:]]
                if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                    args.insert(2, n)  # number of repeats
                    n = 1
            elif m is nn.BatchNorm2d:
                args = [ch[f]]
            elif m is Concat:
                c2 = sum(ch[x] for x in f)
            # TODO: channel, gw, gd
            elif m in {Detect, Segment}:
                args.append([ch[x] for x in f])
                if isinstance(args[1], int):  # number of anchors
                    args[1] = [list(range(args[1] * 2))] * len(f)
                if m is Segment:
                    args[3] = make_divisible(args[3] * gw, 8)
            elif m is Contract:
                c2 = ch[f] * args[0] ** 2
            elif m is Expand:
                c2 = ch[f] // args[0] ** 2
            # -------------------------------------------------------------------------------------
            elif m in {}:  # placeholder set: the EfficientViT classes are registered here in step (3)
                m = m(*args)
                c2 = m.channel
            # -------------------------------------------------------------------------------------
            else:
                c2 = ch[f]
    
            # -------------------------------------------------------------------------------------
            if isinstance(c2, list):  # a backbone module reports a list of output channels (one per feature map)
                is_backbone = True
                m_ = m
                m_.backbone = True
            else:
                m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
                t = str(m)[8:-2].replace('__main__.', '')  # module type
            # -------------------------------------------------------------------------------------
    
            np = sum(x.numel() for x in m_.parameters())  # number params
            # -------------------------------------------------------------------------------------
            # m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
            m_.i, m_.f, m_.type, m_.np = i + 4 if is_backbone else i, f, t, np  # attach index, 'from' index, type, number params
            # -------------------------------------------------------------------------------------
            
            LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
            save.extend(x % (i + 4 if is_backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
            # save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
            layers.append(m_)
            if i == 0:
                ch = []
            
            # -------------------------------------------------------------------------------------
            if isinstance(c2, list):
                ch.extend(c2)
                # pad the channel list to 5 entries so that layer indices after the backbone start at 5
                for _ in range(5 - len(ch)):
                    ch.insert(0, 0)
            else:
                ch.append(c2)
            # -------------------------------------------------------------------------------------
    
        return nn.Sequential(*layers), sorted(save)
    
    
    def _forward_once(self, x, profile=False, visualize=False):
            y, dt = [], []  # outputs
            for m in self.model:
                if m.f != -1:  # if not from previous layer
                    x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
                if profile:
                    self._profile_one_layer(m, x, dt)
                if hasattr(m, 'backbone'):
                    x = m(x)  # the backbone returns a list of feature maps
                    for _ in range(5 - len(x)):
                        x.insert(0, None)  # pad to 5 entries to match the index offset used in parse_model
                    for i_idx, i in enumerate(x):
                        if i_idx in self.save:
                            y.append(i)
                        else:
                            y.append(None)
                    x = x[-1]  # continue with the deepest feature map
                else:
                    x = m(x)  # run
                    y.append(x if m.i in self.save else None)  # save output
                if visualize:
                    feature_visualization(x, m.type, m.i, save_dir=visualize)
            return x
    

    (2) Create a new models/backbone directory, add a new file EfficientViT.py in it, and insert the following code:

    # --------------------------------------------------------
    # EfficientViT Model Architecture for Downstream Tasks
    # Copyright (c) 2022 Microsoft
    # Written by: Xinyu Liu
    # --------------------------------------------------------
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.checkpoint as checkpoint
    import itertools
    
    from timm.models.layers import SqueezeExcite
    
    import numpy as np
    
    __all__ = ['EfficientViT_M0', 'EfficientViT_M1', 'EfficientViT_M2', 'EfficientViT_M3', 'EfficientViT_M4', 'EfficientViT_M5']
    
    class Conv2d_BN(torch.nn.Sequential):
        def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                     groups=1, bn_weight_init=1, resolution=-10000):
            super().__init__()
            self.add_module('c', torch.nn.Conv2d(
                a, b, ks, stride, pad, dilation, groups, bias=False))
            self.add_module('bn', torch.nn.BatchNorm2d(b))
            torch.nn.init.constant_(self.bn.weight, bn_weight_init)
            torch.nn.init.constant_(self.bn.bias, 0)
    
        @torch.no_grad()
        def fuse(self):
            c, bn = self._modules.values()
            w = bn.weight / (bn.running_var + bn.eps)**0.5
            w = c.weight * w[:, None, None, None]
            b = bn.bias - bn.running_mean * bn.weight / \
                (bn.running_var + bn.eps)**0.5
            m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
                0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
            m.weight.data.copy_(w)
            m.bias.data.copy_(b)
            return m
    
    def replace_batchnorm(net):
        for child_name, child in net.named_children():
            if hasattr(child, 'fuse'):
                setattr(net, child_name, child.fuse())
            elif isinstance(child, torch.nn.BatchNorm2d):
                setattr(net, child_name, torch.nn.Identity())
            else:
                replace_batchnorm(child)
                
    
    class PatchMerging(torch.nn.Module):
        def __init__(self, dim, out_dim, input_resolution):
            super().__init__()
            hid_dim = int(dim * 4)
            self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0, resolution=input_resolution)
            self.act = torch.nn.ReLU()
            self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, resolution=input_resolution)
            self.se = SqueezeExcite(hid_dim, .25)
            self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0, resolution=input_resolution // 2)
    
        def forward(self, x):
            x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
            return x
    
    
    class Residual(torch.nn.Module):
        def __init__(self, m, drop=0.):
            super().__init__()
            self.m = m
            self.drop = drop
    
        def forward(self, x):
            if self.training and self.drop > 0:
                return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                                  device=x.device).ge_(self.drop).div(1 - self.drop).detach()
            else:
                return x + self.m(x)
    
    
    class FFN(torch.nn.Module):
        def __init__(self, ed, h, resolution):
            super().__init__()
            self.pw1 = Conv2d_BN(ed, h, resolution=resolution)
            self.act = torch.nn.ReLU()
            self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0, resolution=resolution)
    
        def forward(self, x):
            x = self.pw2(self.act(self.pw1(x)))
            return x
    
    
    class CascadedGroupAttention(torch.nn.Module):
        r""" Cascaded Group Attention.
    
        Args:
            dim (int): Number of input channels.
            key_dim (int): The dimension for query and key.
            num_heads (int): Number of attention heads.
            attn_ratio (int): Multiplier for the query dim for value dimension.
            resolution (int): Input resolution, correspond to the window size.
            kernels (List[int]): The kernel size of the dw conv on query.
        """
        def __init__(self, dim, key_dim, num_heads=8,
                     attn_ratio=4,
                     resolution=14,
                     kernels=[5, 5, 5, 5],):
            super().__init__()
            self.num_heads = num_heads
            self.scale = key_dim ** -0.5
            self.key_dim = key_dim
            self.d = int(attn_ratio * key_dim)
            self.attn_ratio = attn_ratio
    
            qkvs = []
            dws = []
            for i in range(num_heads):
                qkvs.append(Conv2d_BN(dim // (num_heads), self.key_dim * 2 + self.d, resolution=resolution))
                dws.append(Conv2d_BN(self.key_dim, self.key_dim, kernels[i], 1, kernels[i]//2, groups=self.key_dim, resolution=resolution))
            self.qkvs = torch.nn.ModuleList(qkvs)
            self.dws = torch.nn.ModuleList(dws)
            self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
                self.d * num_heads, dim, bn_weight_init=0, resolution=resolution))
    
            points = list(itertools.product(range(resolution), range(resolution)))
            N = len(points)
            attention_offsets = {}
            idxs = []
            for p1 in points:
                for p2 in points:
                    offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                    if offset not in attention_offsets:
                        attention_offsets[offset] = len(attention_offsets)
                    idxs.append(attention_offsets[offset])
            self.attention_biases = torch.nn.Parameter(
                torch.zeros(num_heads, len(attention_offsets)))
            self.register_buffer('attention_bias_idxs',
                                 torch.LongTensor(idxs).view(N, N))
    
        @torch.no_grad()
        def train(self, mode=True):
            super().train(mode)
            if mode and hasattr(self, 'ab'):
                del self.ab
            else:
                self.ab = self.attention_biases[:, self.attention_bias_idxs]
    
        def forward(self, x):  # x (B,C,H,W)
            B, C, H, W = x.shape
            trainingab = self.attention_biases[:, self.attention_bias_idxs]
            feats_in = x.chunk(len(self.qkvs), dim=1)
            feats_out = []
            feat = feats_in[0]
            for i, qkv in enumerate(self.qkvs):
                if i > 0: # add the previous output to the input
                    feat = feat + feats_in[i]
                feat = qkv(feat)
                q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.d], dim=1) # B, C/h, H, W
                q = self.dws[i](q)
                q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N
                attn = (
                    (q.transpose(-2, -1) @ k) * self.scale
                    +
                    (trainingab[i] if self.training else self.ab[i])
                )
                attn = attn.softmax(dim=-1) # BNN
                feat = (v @ attn.transpose(-2, -1)).view(B, self.d, H, W) # BCHW
                feats_out.append(feat)
            x = self.proj(torch.cat(feats_out, 1))
            return x
    
    
    class LocalWindowAttention(torch.nn.Module):
        r""" Local Window Attention.
    
        Args:
            dim (int): Number of input channels.
            key_dim (int): The dimension for query and key.
            num_heads (int): Number of attention heads.
            attn_ratio (int): Multiplier for the query dim for value dimension.
            resolution (int): Input resolution.
            window_resolution (int): Local window resolution.
            kernels (List[int]): The kernel size of the dw conv on query.
        """
        def __init__(self, dim, key_dim, num_heads=8,
                     attn_ratio=4,
                     resolution=14,
                     window_resolution=7,
                     kernels=[5, 5, 5, 5],):
            super().__init__()
            self.dim = dim
            self.num_heads = num_heads
            self.resolution = resolution
            assert window_resolution > 0, 'window_size must be greater than 0'
            self.window_resolution = window_resolution
            
            self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
                                    attn_ratio=attn_ratio, 
                                    resolution=window_resolution,
                                    kernels=kernels,)
    
        def forward(self, x):
            B, C, H, W = x.shape
                   
            if H <= self.window_resolution and W <= self.window_resolution:
                x = self.attn(x)
            else:
                x = x.permute(0, 2, 3, 1)
                pad_b = (self.window_resolution - H %
                         self.window_resolution) % self.window_resolution
                pad_r = (self.window_resolution - W %
                         self.window_resolution) % self.window_resolution
                padding = pad_b > 0 or pad_r > 0
    
                if padding:
                    x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
    
                pH, pW = H + pad_b, W + pad_r
                nH = pH // self.window_resolution
                nW = pW // self.window_resolution
                # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
                x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3).reshape(
                    B * nH * nW, self.window_resolution, self.window_resolution, C
                ).permute(0, 3, 1, 2)
                x = self.attn(x)
                # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
                x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution,
                           C).transpose(2, 3).reshape(B, pH, pW, C)
    
                if padding:
                    x = x[:, :H, :W].contiguous()
    
                x = x.permute(0, 3, 1, 2)
    
            return x
    
    
    class EfficientViTBlock(torch.nn.Module):
        """ A basic EfficientViT building block.
    
        Args:
            type (str): Type for token mixer. Default: 's' for self-attention.
            ed (int): Number of input channels.
            kd (int): Dimension for query and key in the token mixer.
            nh (int): Number of attention heads.
            ar (int): Multiplier for the query dim for value dimension.
            resolution (int): Input resolution.
            window_resolution (int): Local window resolution.
            kernels (List[int]): The kernel size of the dw conv on query.
        """
        def __init__(self, type,
                     ed, kd, nh=8,
                     ar=4,
                     resolution=14,
                     window_resolution=7,
                     kernels=[5, 5, 5, 5],):
            super().__init__()
                
            self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
            self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))
    
            if type == 's':
                self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar, \
                        resolution=resolution, window_resolution=window_resolution, kernels=kernels))
                    
            self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
            self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))
    
        def forward(self, x):
            return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
    
    
    class EfficientViT(torch.nn.Module):
        def __init__(self, img_size=400,
                     patch_size=16,
                     frozen_stages=0,
                     in_chans=3,
                     stages=['s', 's', 's'],
                     embed_dim=[64, 128, 192],
                     key_dim=[16, 16, 16],
                     depth=[1, 2, 3],
                     num_heads=[4, 4, 4],
                     window_size=[7, 7, 7],
                     kernels=[5, 5, 5, 5],
                     down_ops=[['subsample', 2], ['subsample', 2], ['']],
                     pretrained=None,
                     distillation=False,):
            super().__init__()
    
            resolution = img_size
            self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution), torch.nn.ReLU(),
                               Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2), torch.nn.ReLU(),
                               Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4), torch.nn.ReLU(),
                               Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 1, 1, resolution=resolution // 8))
    
            resolution = img_size // patch_size
            attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
            self.blocks1 = []
            self.blocks2 = []
            self.blocks3 = []
            for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                    zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
                for d in range(dpth):
                    eval('self.blocks' + str(i+1)).append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
                if do[0] == 'subsample':
                    #('Subsample' stride)
                    blk = eval('self.blocks' + str(i+2))
                    resolution_ = (resolution - 1) // do[1] + 1
                    blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i], resolution=resolution)),
                                        Residual(FFN(embed_dim[i], int(embed_dim[i] * 2), resolution)),))
                    blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))
                    resolution = resolution_
                    blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1], resolution=resolution)),
                                        Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2), resolution)),))
            self.blocks1 = torch.nn.Sequential(*self.blocks1)
            self.blocks2 = torch.nn.Sequential(*self.blocks2)
            self.blocks3 = torch.nn.Sequential(*self.blocks3)
            
            self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
    
        def forward(self, x):
            outs = []
            x = self.patch_embed(x)
            x = self.blocks1(x)
            outs.append(x)
            x = self.blocks2(x)
            outs.append(x)
            x = self.blocks3(x)
            outs.append(x)
            return outs
    
    EfficientViT_m0 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [64, 128, 192],
            'depth': [1, 2, 3],
            'num_heads': [4, 4, 4],
            'window_size': [7, 7, 7],
            'kernels': [7, 5, 3, 3],
        }
    
    EfficientViT_m1 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [128, 144, 192],
            'depth': [1, 2, 3],
            'num_heads': [2, 3, 3],
            'window_size': [7, 7, 7],
            'kernels': [7, 5, 3, 3],
        }
    
    EfficientViT_m2 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [128, 192, 224],
            'depth': [1, 2, 3],
            'num_heads': [4, 3, 2],
            'window_size': [7, 7, 7],
            'kernels': [7, 5, 3, 3],
        }
    
    EfficientViT_m3 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [128, 240, 320],
            'depth': [1, 2, 3],
            'num_heads': [4, 3, 4],
            'window_size': [7, 7, 7],
            'kernels': [5, 5, 5, 5],
        }
    
    EfficientViT_m4 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [128, 256, 384],
            'depth': [1, 2, 3],
            'num_heads': [4, 4, 4],
            'window_size': [7, 7, 7],
            'kernels': [7, 5, 3, 3],
        }
    
    EfficientViT_m5 = {
            'img_size': 224,
            'patch_size': 16,
            'embed_dim': [192, 288, 384],
            'depth': [1, 3, 4],
            'num_heads': [3, 3, 4],
            'window_size': [7, 7, 7],
            'kernels': [7, 5, 3, 3],
        }
    
    def EfficientViT_M0(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m0):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
    
    def EfficientViT_M1(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m1):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
    
    def EfficientViT_M2(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m2):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
    
    def EfficientViT_M3(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m3):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
        
    def EfficientViT_M4(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m4):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
    
    def EfficientViT_M5(pretrained='', frozen_stages=0, distillation=False, fuse=False, pretrained_cfg=None, model_cfg=EfficientViT_m5):
        model = EfficientViT(frozen_stages=frozen_stages, distillation=distillation, pretrained=pretrained, **model_cfg)
        if pretrained:
            model.load_state_dict(update_weight(model.state_dict(), torch.load(pretrained)['model']))
        if fuse:
            replace_batchnorm(model)
        return model
    
    def update_weight(model_dict, weight_dict):
        idx, temp_dict = 0, {}
        for k, v in weight_dict.items():
            # k = k[9:]
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                idx += 1
        model_dict.update(temp_dict)
        print(f'loading weights... {idx}/{len(model_dict)} items')
        return model_dict
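
    As a quick standalone sanity check (assuming the file is saved as models/backbone/EfficientViT.py and timm is installed), the backbone can be run on a dummy image to confirm it returns three feature maps; this snippet is optional and not required for the YOLOv5 integration:

    if __name__ == '__main__':
        # expected for EfficientViT_M0 with a 640x640 input: roughly
        # (1, 64, 80, 80), (1, 128, 40, 40), (1, 192, 20, 20)
        model = EfficientViT_M0()
        feats = model(torch.randn(1, 3, 640, 640))
        print([tuple(f.shape) for f in feats])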
    

    (3) In models/yolo.py, import the EfficientViT models and fill the placeholder added to parse_model in step (1) as follows:

    
    from models.backbone.EfficientViT import *
    ---------------------------------------------------
    elif m in {EfficientViT_M0, EfficientViT_M1, EfficientViT_M2, EfficientViT_M3, EfficientViT_M4, EfficientViT_M5}:
        m = m(*args)
        c2 = m.channel
    ---------------------------------------------------
    

    (4) Create a new configuration file under models: yolov5-efficientvit.yaml

    
    # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
    # Parameters
    nc: 80  # number of classes
    depth_multiple: 0.33  # model depth multiple
    width_multiple: 0.25  # layer channel multiple
    anchors:
      - [10,13, 16,30, 33,23]  # P3/8
      - [30,61, 62,45, 59,119]  # P4/16
      - [116,90, 156,198, 373,326]  # P5/32
    
    # YOLOv5 v6.0 backbone
    backbone:
      # [from, number, module, args]
      [[-1, 1, EfficientViT_M0, []], # 4
       [-1, 1, SPPF, [1024, 5]],  # 5
      ]
    
    # YOLOv5 v6.0 head
    head:
      [[-1, 1, Conv, [512, 1, 1]], # 6
       [-1, 1, nn.Upsample, [None, 2, 'nearest']], # 7
       [[-1, 3], 1, Concat, [1]],  # cat backbone P4 8
       [-1, 3, C3, [512, False]],  # 9
    
       [-1, 1, Conv, [256, 1, 1]], # 10
       [-1, 1, nn.Upsample, [None, 2, 'nearest']], # 11
       [[-1, 2], 1, Concat, [1]],  # cat backbone P3 12
       [-1, 3, C3, [256, False]],  # 13 (P3/8-small)
    
       [-1, 1, Conv, [256, 3, 2]], # 14
       [[-1, 10], 1, Concat, [1]],  # cat head P4 15
       [-1, 3, C3, [512, False]],  # 16 (P4/16-medium)
    
       [-1, 1, Conv, [512, 3, 2]], # 17
       [[-1, 5], 1, Concat, [1]],  # cat head P5 18
       [-1, 3, C3, [1024, False]],  # 19 (P5/32-large)
    
       [[13, 16, 19], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
      ]
    

    (5) Run verification: in models/yolo.py, set the --cfg argument to the new yolov5-efficientvit.yaml.
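
    For example (assuming the YAML file is saved under models/), the build check can be run with:

    python models/yolo.py --cfg models/yolov5-efficientvit.yaml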

        from  n    params  module                                  arguments                     
      0                -1  1   2155680  EfficientViT_M0                         []                            
      1                -1  1    117440  models.common.SPPF                      [192, 256, 5]                 
      2                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]              
      3                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
      4           [-1, 3]  1         0  models.common.Concat                    [1]                           
      5                -1  1     90880  models.common.C3                        [256, 128, 1, False]          
      6                -1  1      8320  models.common.Conv                      [128, 64, 1, 1]               
      7                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']          
      8           [-1, 2]  1         0  models.common.Concat                    [1]                           
      9                -1  1     22912  models.common.C3                        [128, 64, 1, False]           
     10                -1  1     36992  models.common.Conv                      [64, 64, 3, 2]                
     11          [-1, 10]  1         0  models.common.Concat                    [1]                           
     12                -1  1     74496  models.common.C3                        [128, 128, 1, False]          
     13                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]              
     14           [-1, 5]  1         0  models.common.Concat                    [1]                           
     15                -1  1    329216  models.common.C3                        [384, 256, 1, False]          
     16      [13, 16, 19]  1    115005  Detect                                  [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [64, 128, 256]]
    YOLOv5-efficientvit summary: 582 layers, 3131677 parameters, 3131677 gradients
    Fusing layers... 
    YOLOv5-efficientvit summary: 556 layers, 3129213 parameters, 3129213 gradients
    

    The project currently plans to cover at least 50+ Vision Transformer backbones, along with a number of other improvement strategies. Weights and training results of the improved models trained from scratch on the MS COCO dataset will also be released in follow-up updates. If you would like to learn more about the project, send a private message to the author or follow the WeChat official account BestSongC and send "yolo改进" to receive the project information.

  • Original article: https://blog.csdn.net/sc1434404661/article/details/136287484