• YOLOv7改进:GAMAttention注意力机制


    1.背景介绍
    为了提高各种计算机视觉任务的性能,人们研究了各种注意机制。然而,以往的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局调度机制,通过减少信息缩减和放大全局交互表示来提高深度神经网络的性能。我们沿着卷积空间注意子模块引入了用于通道注意的多层感知器3D置换。

    论文题目:Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
    论文地址:https://paperswithcode.com/paper/global-attention-mechanism-retain-information

    GAMAttention注意力机制原理图

    对于ImageNet-1K,我们将图像预处理为224×224(He et al.[2016])。我们包括ResNet18和ResNet50(He et al.[2016]),以验证不同网络深度的方法推广。对于ResNet50,我们将其与群卷积进行了比较,以防止参数显著增加。我们将起始学习率设置为0.1,并每隔30个阶段降低一次。我们总共使用90个训练时段。在空间注意子模块中,我们将第一个块的第一步从1切换到2,以匹配特征的大小。为了进行公平比较,CBAM保留了其他设置,包括在空间注意子模块中使用最大池。3 MobileNet V2是用于图像分类的最高效的轻量级模型之一。我们对MobileNet V2使用相同的ResNet设置,只是使用了0.045的初始学习率和4×10的权重衰减−5.对ImageNet-1K的评估如表所示。它表明GAM可以稳定地提高不同神经架构的性能。尤其是对于ResNet18,GAM以更少的参数和更好的效率优于ABN。

    相关实验结果

    对ImageNet-1K的评估如表2所示,它表明GAM可以稳定地提高不同神经体系结构的性能。特别是,对于ResNet18,GAM的性能优于ABN,参数更少,效率更高。

     为了更好地理解空间注意和通道注意分别对消融的贡献,我们通过开启和关闭一种方式进行了消融研究。例如,ch表示空间注意力被关闭,而频道注意力被打开。SP表示通道关注已关闭,空间关注已打开。结果如表3所示。我们可以在两个开关实验中观察到性能的提高。结果表明,空间关注度和通道关注度对性能增益均有贡献。请注意,它们的组合进一步提高了性能。

     将GAM与CBAM在使用和不使用ResNet18最大池化的情况下进行比较。表4显示了结果。可以观察到,在这两种情况下,我们的方法都优于CBAM。

    2.YOLOv7改进方法

    2.1增加以下GAMAttention.yaml文件

    1. # YOLOv7 🚀, GPL-3.0 license
    2. # parameters
    3. nc: 80 # number of classes
    4. depth_multiple: 0.33 # model depth multiple
    5. width_multiple: 1.0 # layer channel multiple
    6. # anchors
    7. anchors:
    8. - [12,16, 19,36, 40,28] # P3/8
    9. - [36,75, 76,55, 72,146] # P4/16
    10. - [142,110, 192,243, 459,401] # P5/32
    11. # yolov7 backbone by yoloair
    12. backbone:
    13. # [from, number, module, args]
    14. [[-1, 1, Conv, [32, 3, 1]], # 0
    15. [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
    16. [-1, 1, Conv, [64, 3, 1]],
    17. [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
    18. [-1, 1, CNeB, [128]],
    19. [-1, 1, Conv, [256, 3, 2]],
    20. [-1, 1, MP, []],
    21. [-1, 1, Conv, [128, 1, 1]],
    22. [-3, 1, Conv, [128, 1, 1]],
    23. [-1, 1, Conv, [128, 3, 2]],
    24. [[-1, -3], 1, Concat, [1]], # 16-P3/8
    25. [-1, 1, Conv, [128, 1, 1]],
    26. [-2, 1, Conv, [128, 1, 1]],
    27. [-1, 1, Conv, [128, 3, 1]],
    28. [-1, 1, Conv, [128, 3, 1]],
    29. [-1, 1, Conv, [128, 3, 1]],
    30. [-1, 1, Conv, [128, 3, 1]],
    31. [[-1, -3, -5, -6], 1, Concat, [1]],
    32. [-1, 1, Conv, [512, 1, 1]],
    33. [-1, 1, MP, []],
    34. [-1, 1, Conv, [256, 1, 1]],
    35. [-3, 1, Conv, [256, 1, 1]],
    36. [-1, 1, Conv, [256, 3, 2]],
    37. [[-1, -3], 1, Concat, [1]],
    38. [-1, 1, Conv, [256, 1, 1]],
    39. [-2, 1, Conv, [256, 1, 1]],
    40. [-1, 1, Conv, [256, 3, 1]],
    41. [-1, 1, Conv, [256, 3, 1]],
    42. [-1, 1, Conv, [256, 3, 1]],
    43. [-1, 1, Conv, [256, 3, 1]],
    44. [[-1, -3, -5, -6], 1, Concat, [1]],
    45. [-1, 1, Conv, [1024, 1, 1]],
    46. [-1, 1, MP, []],
    47. [-1, 1, Conv, [512, 1, 1]],
    48. [-3, 1, Conv, [512, 1, 1]],
    49. [-1, 1, Conv, [512, 3, 2]],
    50. [[-1, -3], 1, Concat, [1]],
    51. [-1, 1, CNeB, [1024]],
    52. [-1, 1, Conv, [256, 3, 1]],
    53. ]
    54. # yolov7 head by yoloair
    55. head:
    56. [[-1, 1, SPPCSPC, [512]],
    57. [-1, 1, Conv, [256, 1, 1]],
    58. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    59. [31, 1, Conv, [256, 1, 1]],
    60. [[-1, -2], 1, Concat, [1]],
    61. [-1, 1, C3C2, [128]],
    62. [-1, 1, Conv, [128, 1, 1]],
    63. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
    64. [18, 1, Conv, [128, 1, 1]],
    65. [[-1, -2], 1, Concat, [1]],
    66. [-1, 1, C3C2, [128]],
    67. [-1, 1, MP, []],
    68. [-1, 1, Conv, [128, 1, 1]],
    69. [-3, 1, GAMAttention, [128]],
    70. [-1, 1, Conv, [128, 3, 2]],
    71. [[-1, -3, 44], 1, Concat, [1]],
    72. [-1, 1, C3C2, [256]],
    73. [-1, 1, MP, []],
    74. [-1, 1, Conv, [256, 1, 1]],
    75. [-3, 1, Conv, [256, 1, 1]],
    76. [-1, 1, Conv, [256, 3, 2]],
    77. [[-1, -3, 39], 1, Concat, [1]],
    78. [-1, 3, C3C2, [512]],
    79. # 检测头 -----------------------------
    80. [49, 1, RepConv, [256, 3, 1]],
    81. [55, 1, RepConv, [512, 3, 1]],
    82. [61, 1, RepConv, [1024, 3, 1]],
    83. [[62,63,64], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
    84. ]

    2.2common.py配置

    ./models/common.py文件增加以下模块

    1. import numpy as np
    2. import torch
    3. from torch import nn
    4. from torch.nn import init
    5. class GAMAttention(nn.Module):
    6. #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    7. def __init__(self, c1, c2, group=True,rate=4):
    8. super(GAMAttention, self).__init__()
    9. self.channel_attention = nn.Sequential(
    10. nn.Linear(c1, int(c1 / rate)),
    11. nn.ReLU(inplace=True),
    12. nn.Linear(int(c1 / rate), c1)
    13. )
    14. self.spatial_attention = nn.Sequential(
    15. nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
    16. nn.BatchNorm2d(int(c1 /rate)),
    17. nn.ReLU(inplace=True),
    18. nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
    19. nn.BatchNorm2d(c2)
    20. )
    21. def forward(self, x):
    22. b, c, h, w = x.shape
    23. x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
    24. x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
    25. x_channel_att = x_att_permute.permute(0, 3, 1, 2)
    26. x = x * x_channel_att
    27. x_spatial_att = self.spatial_attention(x).sigmoid()
    28. x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle
    29. out = x * x_spatial_att
    30. return out
    31. def channel_shuffle(x, groups=2):
    32. B, C, H, W = x.size()
    33. out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    34. out=out.view(B, C, H, W)
    35. return out

    2.3yolo.py配置

    在 models/yolo.py文件夹下

    • 定位到parse_model函数中
    • for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
    • 对应位置 下方只需要新增以下代码
    1. elif m is GAMAttention:
    2. c1, c2 = ch[f], args[0]
    3. if c2 != no:
    4. c2 = make_divisible(c2 * gw, 8)
    5. args = [c1, c2, *args[1:]]

    修改完成

  • 相关阅读:
    域策略(7)——禁用本地administrator登录计算机
    Spring事务
    电力电子转战数字IC20220718-19day51-52——TLM通信
    前端周刊第三十期
    21天学习挑战:经典算法---顺序查找
    关于scanf和printf的格式控制修饰符
    山东大学2024深度学习期末考试回忆
    上周热点回顾(11.28-12.4)
    redis做缓存,mysql的数据怎么与redis进行同步(双写一致性)
    spoken english
  • 原文地址:https://blog.csdn.net/weixin_45303602/article/details/133394753