• EfficientViT:高分辨率密集预测的多尺度线性关注


    标题:EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction

    论文:https://arxiv.org/abs/2205.14756

    中文版【读点论文】EfficientViT: Enhanced Linear Attention for High-Resolution Low-Computation将softmax注意力转变为线性注意力_羞儿的博客-CSDN博客

    代码:https://codeload.github.com/mit-han-lab/efficientvit/zip/refs/heads/master

    目录

    一、摘要

    二、主要贡献

    三、方法论

    3.1 Multi-Scale Linear Attention(多尺度线性注意力) 

    3.2 EfficientViT架构

    四、实验

    4.1 消融研究

    4.2 语义分割实验

    五、总结


    一、摘要

    研究背景高分辨率密集预测使许多有吸引力的现实世界的应用,如计算摄影,自动驾驶等,然而,巨大的计算成本使得部署最先进的高分辨率密集预测模型的硬件设备上的困难

    主要工作:本文提出了一种新的多尺度线性attention的高分辨率视觉模型——EfficientViT。与之前的高分辨率密集预测模型依赖于大量的softmax关注、硬件低效的大核卷积或复杂的拓扑结构来获得良好的性能不同,多尺度线性attention只需要轻量级和硬件高效的操作就能实现全局接受场和多尺度学习(高分辨率密集预测的两个理想特征)。

    研究成果:因此,在各种硬件平台(包括移动CPU、边缘GPU和云GPU)上,EfficientViT比以前的最先进型号提供了显著的性能提升。在Cityscapes(数据集)上没有性能损失的情况下,EfficientViT分别比SegFormer和SegNeXt提供了高达13.9倍和6.2倍的GPU延迟减少。对于超分辨率,EfficientViT比Restormer提供高达6.4倍的加速,同时提供0.11dB的PSNR增益。

    二、主要贡献

    1. 引入了一个新的多尺度线性注意力模块,用于高效的高分辨率稠密预测。它实现了全局感受野多尺度学习同时保持了良好的硬件效率。据我们所知,我们的工作是第一个证明线性注意力对高分辨率密集预测的有效性

    2. 我们设计了高效vit,一个新的高分辨率系列基于视觉模型,提出了多尺度线性注意模块

    3. EfficientViT在不同硬件平台(移动的CPU,边缘GPU和云GPU)上的语义分割,超分辨率,分割任何东西和ImageNet分类方面都比以前的SOTA模型有显著的加速。

    三、方法论

    3.1 Multi-Scale Linear Attention(多尺度线性注意力) 

    多尺度线性注意力仅通过硬件高效的操作同时实现了全局感受野和多尺度学习。基于多尺度线性注意力,作者提出了一种新的用于高分辨率密集预测的Vision transformer模型EfficientVit。  

    动机:从性能角度来看,全局感受野和多尺度学习是必不可少的。以前的 SOTA 高分辨率密集预测模型通过启用这些特征提供了较强的性能,但不能提供良好的效率。多尺度线性注意力模块通过用轻微的性能损失换取显著的效率提升来解决这个问题。

    方法使用ReLU线性注意力来实现全局感受野,而不是繁重的softmax注意力。

    ReLU线性注意力的公式推导

    由传统的softmax注意力公式和Relu注意力相似度计算函数(相似度计算函数替换为Relu版的),可得:

    由矩阵乘法的结合律,可得:

    推导最终结论:由公式(3)所示,只需要计算\in \mathbb{R}^{d\times1}一次,就可以对每个Query重用它们(多头attention机制查询无关问题的最终解???),从而只需要O(N)的计算代价和O(N)的内存。 

      

    ReLU线性注意力的局限性:如下图所示,softmax 注意和 ReLU 线性注意的注意图。由于缺乏非线性相似函数,ReLU 线性注意不能生成集中的注意图,捕获局部信息的能力较弱。(ReLU线性注意力缺点暴露)

    解决方案:

    1. 为了减轻其局限性,我们提出用卷积增强 ReLU 线性注意力。具体来说,在每个 FFN 层中插入深度卷积。如下图所示,其中 ReLU 线性注意力捕获上下文信息,FFN+DWConv 捕获局部信息

    2. 将邻近的 Q/K/V token信息聚合(拼接)成多尺度token以增强 ReLU 线性注意的多尺度学习能力这里多尺度是指通道方向上的不同尺度,所以聚合能多尺度学习能力)。

    具体来说,将所有DWConv融合成单个DWConv组,将所有 1x1 Convs合并成单个1x1的卷积组,组数为3 × #head,每组通道数为d。得到多尺度token后,对其进行ReLU线性注意力,提取多尺度全局特征。最后,将特征沿头部维度进行连接,并将其提供给最终的线性层以融合特征。

    (本质上是使用nn.Conv2d()函数中的groups参数,将输入和输出通道分成几组进行卷积操作,学习通道方向上的不同尺度的信息。)

    Q:感受野和注意力机制有什么关系?

    A:注意力机制可以通过计算不同位置之间的关系,来捕捉长距离依赖关系,从而扩大感受野,提高网络的感知能力。

    代码如下

    轻量权重多尺度注意力模块

    1. # 轻量权重多尺度注意力
    2. class LiteMLA(nn.Module):
    3. r"""Lightweight multi-scale linear attention"""
    4. def __init__(
    5. self,
    6. in_channels: int,
    7. out_channels: int,
    8. heads: int or None = None,
    9. heads_ratio: float = 1.0,
    10. dim=8,
    11. use_bias=False,
    12. norm=(None, "bn2d"),
    13. act_func=(None, None),
    14. kernel_func="relu",
    15. scales: tuple[int, ...] = (5,),
    16. eps=1.0e-15,
    17. ):
    18. super(LiteMLA, self).__init__()
    19. self.eps = eps
    20. heads = heads or int(in_channels // dim * heads_ratio)
    21. total_dim = heads * dim
    22. use_bias = val2tuple(use_bias, 2)
    23. norm = val2tuple(norm, 2)
    24. act_func = val2tuple(act_func, 2)
    25. self.dim = dim
    26. self.qkv = ConvLayer(
    27. in_channels,
    28. 3 * total_dim,
    29. 1,
    30. use_bias=use_bias[0],
    31. norm=norm[0],
    32. act_func=act_func[0],
    33. )
    34. self.aggreg = nn.ModuleList(
    35. [
    36. nn.Sequential(
    37. nn.Conv2d(
    38. 3 * total_dim,
    39. 3 * total_dim,
    40. scale,
    41. padding=get_same_padding(scale),
    42. groups=3 * total_dim,
    43. bias=use_bias[0],
    44. ),
    45. nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
    46. )
    47. for scale in scales
    48. ]
    49. ) # nn.Conv2d()函数中的groups参数是指将输入和输出通道分成几组进行卷积操作
    50. self.kernel_func = build_act(kernel_func, inplace=False) # Relu激活函数
    51. self.proj = ConvLayer(
    52. total_dim * (1 + len(scales)),
    53. out_channels,
    54. 1,
    55. use_bias=use_bias[1],
    56. norm=norm[1],
    57. act_func=act_func[1],
    58. )
    59. @autocast(enabled=False)
    60. def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
    61. B, _, H, W = list(qkv.size())
    62. if qkv.dtype == torch.float16:
    63. qkv = qkv.float()
    64. qkv = torch.reshape(
    65. qkv,
    66. (
    67. B,
    68. -1,
    69. 3 * self.dim,
    70. H * W,
    71. ),
    72. )
    73. qkv = torch.transpose(qkv, -1, -2)
    74. q, k, v = (
    75. qkv[..., 0 : self.dim],
    76. qkv[..., self.dim : 2 * self.dim],
    77. qkv[..., 2 * self.dim :],
    78. )
    79. # lightweight linear attention
    80. q = self.kernel_func(q) # 进行relu激活
    81. k = self.kernel_func(k) # 进行relu激活
    82. # linear matmul
    83. trans_k = k.transpose(-1, -2)
    84. v = F.pad(v, (0, 1), mode="constant", value=1) # 进行维度扩展
    85. kv = torch.matmul(trans_k, v) # 按推导公式计算
    86. out = torch.matmul(q, kv)
    87. out = out[..., :-1] / (out[..., -1:] + self.eps)
    88. out = torch.transpose(out, -1, -2)
    89. out = torch.reshape(out, (B, -1, H, W))
    90. return out
    91. def forward(self, x: torch.Tensor) -> torch.Tensor:
    92. # generate multi-scale q, k, v
    93. qkv = self.qkv(x) # 获取Q、K、V,由1x1卷积得到
    94. multi_scale_qkv = [qkv]
    95. for op in self.aggreg: # 卷积聚合,学习通道上的多尺度信息
    96. multi_scale_qkv.append(op(qkv))
    97. multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) # Q、K、V拼接
    98. out = self.relu_linear_att(multi_scale_qkv) # 重新等分划分为Q,K,V,馈入ReLU线性注意力
    99. out = self.proj(out) # 1x1卷积输出,模拟线性层
    100. return out

    3.2 EfficientViT架构

    如上图所示,

    Backbone(骨干):由输入层和四个阶段组成,特征图大小逐渐减小,通道数量逐渐增加。在阶段3和4中插入EfficientViT模块。对于下采样,我们使用步幅为2的MBConv。

    Head(分割头):P2、P3和P4表示阶段2、3和4的输出,形成特征图的金字塔。为了简单和高效,使用1x 1卷积和标准上采样操作(例如,双线性/双三次上采样)以匹配它们的空间和信道大小并经由加法来融合它们。简单的头部设计,其包括若干MBConv块和输出层(即,预测和上采样)。

      

    代码如下

    Backbone(骨干)

    1. class EfficientViTBackbone(nn.Module):
    2. # Backbone:input_stem + stage1 + stage2 + stage3 + stage4
    3. def __init__(
    4. self,
    5. width_list: list[int],
    6. depth_list: list[int],
    7. in_channels=3,
    8. dim=32,
    9. expand_ratio=4,
    10. norm="bn2d",
    11. act_func="hswish",
    12. ) -> None:
    13. super().__init__()
    14. self.width_list = []
    15. # input stem
    16. self.input_stem = [
    17. ConvLayer(
    18. in_channels=3,
    19. out_channels=width_list[0],
    20. stride=2,
    21. norm=norm,
    22. act_func=act_func,
    23. ) # 3x3卷积 -> 下采2倍
    24. ]
    25. for _ in range(depth_list[0]):
    26. block = self.build_local_block( # 构建DSConv模块,捕捉局部信息
    27. in_channels=width_list[0],
    28. out_channels=width_list[0],
    29. stride=1,
    30. expand_ratio=1,
    31. norm=norm,
    32. act_func=act_func,
    33. )
    34. self.input_stem.append(ResidualBlock(block, IdentityLayer())) # 增加残差
    35. in_channels = width_list[0]
    36. self.input_stem = OpSequential(self.input_stem) # 把input_stem阶段各模块按顺序添加到ModuleList中
    37. self.width_list.append(in_channels) # 把每个模块的通道数添加到width_list
    38. # stages
    39. self.stages = []
    40. # # # stages1
    41. for w, d in zip(width_list[1:3], depth_list[1:3]):
    42. stage = []
    43. for i in range(d):
    44. stride = 2 if i == 0 else 1
    45. block = self.build_local_block( # 构建MBConv模块,捕捉局部信息
    46. in_channels=in_channels,
    47. out_channels=w,
    48. stride=stride,
    49. expand_ratio=expand_ratio,
    50. norm=norm,
    51. act_func=act_func,
    52. )
    53. block = ResidualBlock(block, IdentityLayer() if stride == 1 else None) # 增加残差
    54. stage.append(block)
    55. in_channels = w
    56. self.stages.append(OpSequential(stage))
    57. self.width_list.append(in_channels)
    58. for w, d in zip(width_list[3:], depth_list[3:]):
    59. stage = []
    60. # # # stages2
    61. block = self.build_local_block( # 构建MBConv模块,捕捉局部信息
    62. in_channels=in_channels,
    63. out_channels=w,
    64. stride=2,
    65. expand_ratio=expand_ratio,
    66. norm=norm,
    67. act_func=act_func,
    68. fewer_norm=True,
    69. )
    70. stage.append(ResidualBlock(block, None))
    71. in_channels = w
    72. # # # stages3、4
    73. for _ in range(d):
    74. stage.append(
    75. EfficientViTBlock( # EfficientViTBlock模块,多尺度注意力提取上下文特征
    76. in_channels=in_channels,
    77. dim=dim,
    78. expand_ratio=expand_ratio,
    79. norm=norm,
    80. act_func=act_func,
    81. )
    82. )
    83. self.stages.append(OpSequential(stage))
    84. self.width_list.append(in_channels)
    85. self.stages = nn.ModuleList(self.stages) # nn.ModuleList,用于存储不同的模块,并自动将每个模块的参数添加到网络中
    86. # 构建DSConv 或 MBConv —> 局部信息
    87. @staticmethod
    88. def build_local_block(
    89. in_channels: int,
    90. out_channels: int,
    91. stride: int,
    92. expand_ratio: float,
    93. norm: str,
    94. act_func: str,
    95. fewer_norm: bool = False,
    96. ) -> nn.Module:
    97. if expand_ratio == 1:
    98. block = DSConv( # DSConv模块
    99. in_channels=in_channels,
    100. out_channels=out_channels,
    101. stride=stride,
    102. use_bias=(True, False) if fewer_norm else False,
    103. norm=(None, norm) if fewer_norm else norm,
    104. act_func=(act_func, None),
    105. )
    106. else:
    107. block = MBConv( # MBConv模块,Mobile倒置残差瓶颈卷积 -> 2倍下采样
    108. in_channels=in_channels,
    109. out_channels=out_channels,
    110. stride=stride,
    111. expand_ratio=expand_ratio,
    112. use_bias=(True, True, False) if fewer_norm else False,
    113. norm=(None, None, norm) if fewer_norm else norm,
    114. act_func=(act_func, act_func, None),
    115. )
    116. return block
    117. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
    118. output_dict = {"input": x}
    119. output_dict["stage0"] = x = self.input_stem(x)
    120. for stage_id, stage in enumerate(self.stages, 1): # 网络的backbone
    121. output_dict["stage%d" % stage_id] = x = stage(x)
    122. output_dict["stage_final"] = x
    123. return output_dict

    DSConv模块

    1. class DSConv(nn.Module):
    2. def __init__(
    3. self,
    4. in_channels: int,
    5. out_channels: int,
    6. kernel_size=3,
    7. stride=1,
    8. use_bias=False,
    9. norm=("bn2d", "bn2d"),
    10. act_func=("relu6", None),
    11. ):
    12. super(DSConv, self).__init__()
    13. use_bias = val2tuple(use_bias, 2)
    14. norm = val2tuple(norm, 2)
    15. act_func = val2tuple(act_func, 2)
    16. self.depth_conv = ConvLayer(
    17. in_channels,
    18. in_channels,
    19. kernel_size,
    20. stride,
    21. groups=in_channels,
    22. norm=norm[0],
    23. act_func=act_func[0],
    24. use_bias=use_bias[0],
    25. )
    26. self.point_conv = ConvLayer(
    27. in_channels,
    28. out_channels,
    29. 1,
    30. norm=norm[1],
    31. act_func=act_func[1],
    32. use_bias=use_bias[1],
    33. )
    34. def forward(self, x: torch.Tensor) -> torch.Tensor:
    35. x = self.depth_conv(x)
    36. x = self.point_conv(x)
    37. return x

    MBConv模块

    1. # MBConv
    2. class MBConv(nn.Module):
    3. def __init__(
    4. self,
    5. in_channels: int,
    6. out_channels: int,
    7. kernel_size=3,
    8. stride=1,
    9. mid_channels=None,
    10. expand_ratio=6,
    11. use_bias=False,
    12. norm=("bn2d", "bn2d", "bn2d"),
    13. act_func=("relu6", "relu6", None),
    14. ):
    15. super(MBConv, self).__init__()
    16. use_bias = val2tuple(use_bias, 3)
    17. norm = val2tuple(norm, 3)
    18. act_func = val2tuple(act_func, 3)
    19. mid_channels = mid_channels or round(in_channels * expand_ratio)
    20. self.inverted_conv = ConvLayer(
    21. in_channels,
    22. mid_channels,
    23. 1,
    24. stride=1,
    25. norm=norm[0],
    26. act_func=act_func[0],
    27. use_bias=use_bias[0],
    28. )
    29. self.depth_conv = ConvLayer(
    30. mid_channels,
    31. mid_channels,
    32. kernel_size,
    33. stride=stride,
    34. groups=mid_channels,
    35. norm=norm[1],
    36. act_func=act_func[1],
    37. use_bias=use_bias[1],
    38. )
    39. self.point_conv = ConvLayer(
    40. mid_channels,
    41. out_channels,
    42. 1,
    43. norm=norm[2],
    44. act_func=act_func[2],
    45. use_bias=use_bias[2],
    46. )
    47. def forward(self, x: torch.Tensor) -> torch.Tensor:
    48. x = self.inverted_conv(x) # 512
    49. x = self.depth_conv(x) # 512
    50. x = self.point_conv(x) # 256
    51. return x

     EfficientViTBlock模块

    1. # EfficientViTBlock模块 —> 提取上下文特征
    2. class EfficientViTBlock(nn.Module):
    3. def __init__(
    4. self,
    5. in_channels: int,
    6. heads_ratio: float = 1.0,
    7. dim=32,
    8. expand_ratio: float = 4,
    9. norm="bn2d",
    10. act_func="hswish",
    11. ):
    12. super(EfficientViTBlock, self).__init__()
    13. self.context_module = ResidualBlock(
    14. LiteMLA( # 轻量权重多尺度注意力
    15. in_channels=in_channels,
    16. out_channels=in_channels,
    17. heads_ratio=heads_ratio,
    18. dim=dim,
    19. norm=(None, norm),
    20. ),
    21. IdentityLayer(),
    22. )
    23. local_module = MBConv(
    24. in_channels=in_channels,
    25. out_channels=in_channels,
    26. expand_ratio=expand_ratio,
    27. use_bias=(True, True, False),
    28. norm=(None, None, norm),
    29. act_func=(act_func, act_func, None),
    30. )
    31. self.local_module = ResidualBlock(local_module, IdentityLayer()) # 添加残差连接
    32. def forward(self, x: torch.Tensor) -> torch.Tensor:
    33. x = self.context_module(x) # 轻量多尺度注意力 -> 全局上下文特征
    34. x = self.local_module(x) # 深度卷积 -> 局部特征
    35. return x

      

    四、实验

    数据集:Cityscapes 和 ADE20K数据集。

    评价指标:mIoU、Params和MAC(乘加累积操作数)。

    4.1 消融研究

    (1)EfficientViT模块的性能测试

    mIoU和MAC在Cityscapes上测量,输入分辨率为1024x2048。重新调整模型的宽度,使它们具有相同的MAC,由上表所示,多尺度学习和全局感受野对于获得良好的语义分割性能至关重要。

    (2)ImageNet上的主干性能对比

    EfficientViT-L2-r384在ImageNet上获得了86.0的top-1精度,比EfficientNetV 2-L提供了+0.3的精度增益,在A100 GPU上提供了2.6倍的加速。

    4.2 语义分割实验

    与先进语义分割模型在Cityscapes数据集上的对比。

    与SegFormer相比,EfficientViT在mIoU更高的边缘GPU(Jetson AGX Orin)上获得了高达13倍的MAC数节省和高达8.8倍的延迟减少。与SegNeXt相比,EfficientViT在边缘GPU上提供高达2.0倍的MAC减少和3.8倍的加速,同时保持更高的mIoU。 

    五、总结

    1. 本文针对高分辨率稠密预测的有效架构设计,引入了一个轻量级的多尺度注意力模块,它同时实现了全局感受野,以及具有轻量级和硬件高效操作的多尺度学习,从而在各种硬件设备上提供了显着的加速,而不会比SOTA高分辨率密集预测模型带来性能损失。

    2. 多尺度线性注意力,使用ReLU线性注意力来实现全局感受野,通过FFN+DWConv 捕获局部信息和卷积聚合捕获多尺度信息,以此克服ReLU线性注意力轻量化所带来的缺点。

  • 相关阅读:
    java计算机毕业设计springboot+vue留学服务管理平台系统(源码+系统+mysql数据库+Lw文档)
    基于多领导者智能体的Olfati算法matlab仿真
    Linux编辑器-gcc/g++使用
    AJAX基础
    PYTHON蓝桥杯——每日一练(简单题)
    大数据因果推理与学习入门综合概述
    【Linux】字节序理解
    揭露测试外包公司,关于外包,你或许听到过这样的声音
    Java秒杀系统方案优化
    Pandas - 数据转换
  • 原文地址:https://blog.csdn.net/qq_45981086/article/details/134088837