• 【YOLO改进】主干插入SKAttention模块(基于MMYOLO)


    SKAttention模块

    论文链接:https://arxiv.org/pdf/1903.06586.pdf

    将SKAttention模块添加到MMYOLO中

    1. 将开源代码SK.py文件复制到mmyolo/models/plugins目录下

    2. 导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS

    3. 确保 class SKAttention中的输入维度为in_channels(因为MMYOLO会提前传入输入维度参数,所以要保持参数名的一致)

    4. 利用@MODELS.register_module()将“class SKAttention(nn.Module)”注册:

    5. 修改mmyolo/models/plugins/__init__.py文件

    6. 在终端运行:

      python setup.py install
    7. 修改对应的配置文件,并且将plugins的参数“type”设置为“SKAttention”(即第4步注册的模块名),可参考【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-CSDN博客

    修改后的SK.py

    1. from collections import OrderedDict
    2. import torch
    3. from torch import nn
    4. from mmyolo.registry import MODELS
    5. @MODELS.register_module()
    6. class SKAttention(nn.Module):
    7. def __init__(self, in_channels=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
    8. super().__init__()
    9. self.d = max(L, in_channels // reduction)
    10. self.convs = nn.ModuleList([])
    11. for k in kernels:
    12. self.convs.append(
    13. nn.Sequential(OrderedDict([
    14. ('conv', nn.Conv2d(in_channels, in_channels, kernel_size=k, padding=k // 2, groups=group)),
    15. ('bn', nn.BatchNorm2d(in_channels)),
    16. ('relu', nn.ReLU())
    17. ]))
    18. )
    19. self.fc = nn.Linear(in_channels, self.d)
    20. self.fcs = nn.ModuleList([])
    21. for i in range(len(kernels)):
    22. self.fcs.append(nn.Linear(self.d, in_channels))
    23. self.softmax = nn.Softmax(dim=0)
    24. def forward(self, x):
    25. bs, c, _, _ = x.size()
    26. conv_outs = []
    27. ### split
    28. for conv in self.convs:
    29. conv_outs.append(conv(x))
    30. feats = torch.stack(conv_outs, 0) # k,bs,channel,h,w
    31. ### fuse
    32. U = sum(conv_outs) # bs,c,h,w
    33. ### reduction channel
    34. S = U.mean(-1).mean(-1) # bs,c
    35. Z = self.fc(S) # bs,d
    36. ### calculate attention weight
    37. weights = []
    38. for fc in self.fcs:
    39. weight = fc(Z)
    40. weights.append(weight.view(bs, c, 1, 1)) # bs,channel
    41. attention_weughts = torch.stack(weights, 0) # k,bs,channel,1,1
    42. attention_weughts = self.softmax(attention_weughts) # k,bs,channel,1,1
    43. ### fuse
    44. V = (attention_weughts * feats).sum(0)
    45. return V
    46. if __name__ == '__main__':
    47. input = torch.randn(50, 512, 7, 7)
    48. se = SKAttention(in_channels=512, reduction=8)
    49. output = se(input)
    50. print(output.shape)

    修改后的__init__.py

    1. # Copyright (c) OpenMMLab. All rights reserved.
    2. from .cbam import CBAM
    3. from .Biformer import BiLevelRoutingAttention
    4. from .A2Attention import DoubleAttention
    5. from .CoordAttention import CoordAtt
    6. from .CoTAttention import CoTAttention
    7. from .ECA import ECAAttention
    8. from .EffectiveSE import EffectiveSEModule
    9. from .EMA import EMA
    10. from .GC import GlobalContext
    11. from .GE import GatherExcite
    12. from .MHSA import MHSA
    13. from .ParNetAttention import ParNetAttention
    14. from .PolarizedSelfAttention import ParallelPolarizedSelfAttention
    15. from .S2Attention import S2Attention
    16. from .SE import SEAttention
    17. from .SequentialSelfAttention import SequentialPolarizedSelfAttention
    18. from .SGE import SpatialGroupEnhance
    19. from .ShuffleAttention import ShuffleAttention
    20. from .SimAM import SimAM
    21. from .SK import SKAttention
    22. __all__ = ['CBAM', 'BiLevelRoutingAttention', 'DoubleAttention', 'CoordAtt','CoTAttention','ECAAttention', 'EffectiveSEModule', 'EMA',
    23. 'GlobalContext', 'GatherExcite', 'MHSA', 'ParNetAttention','ParallelPolarizedSelfAttention','S2Attention','SEAttention',
    24. 'SequentialPolarizedSelfAttention','SpatialGroupEnhance','ShuffleAttention','SimAM','SKAttention']

    修改后的配置文件(以configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py为例)

    1. _base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']
    2. # ========================Frequently modified parameters======================
    3. # -----data related-----
    4. data_root = 'data/coco/' # Root path of data
    5. # Path of train annotation file
    6. train_ann_file = 'annotations/instances_train2017.json'
    7. train_data_prefix = 'train2017/' # Prefix of train image path
    8. # Path of val annotation file
    9. val_ann_file = 'annotations/instances_val2017.json'
    10. val_data_prefix = 'val2017/' # Prefix of val image path
    11. num_classes = 80 # Number of classes for classification
    12. # Batch size of a single GPU during training
    13. train_batch_size_per_gpu = 16
    14. # Worker to pre-fetch data for each single GPU during training
    15. train_num_workers = 8
    16. # persistent_workers must be False if num_workers is 0
    17. persistent_workers = True
    18. # -----model related-----
    19. # Basic size of multi-scale prior box
    20. anchors = [
    21. [(10, 13), (16, 30), (33, 23)], # P3/8
    22. [(30, 61), (62, 45), (59, 119)], # P4/16
    23. [(116, 90), (156, 198), (373, 326)] # P5/32
    24. ]
    25. # -----train val related-----
    26. # Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
    27. base_lr = 0.01
    28. max_epochs = 300 # Maximum training epochs
    29. model_test_cfg = dict(
    30. # The config of multi-label for multi-class prediction.
    31. multi_label=True,
    32. # The number of boxes before NMS
    33. nms_pre=30000,
    34. score_thr=0.001, # Threshold to filter out boxes.
    35. nms=dict(type='nms', iou_threshold=0.65), # NMS type and threshold
    36. max_per_img=300) # Max number of detections of each image
    37. # ========================Possible modified parameters========================
    38. # -----data related-----
    39. img_scale = (640, 640) # width, height
    40. # Dataset type, this will be used to define the dataset
    41. dataset_type = 'YOLOv5CocoDataset'
    42. # Batch size of a single GPU during validation
    43. val_batch_size_per_gpu = 1
    44. # Worker to pre-fetch data for each single GPU during validation
    45. val_num_workers = 2
    46. # Config of batch shapes. Only on val.
    47. # It means not used if batch_shapes_cfg is None.
    48. batch_shapes_cfg = dict(
    49. type='BatchShapePolicy',
    50. batch_size=val_batch_size_per_gpu,
    51. img_size=img_scale[0],
    52. # The image scale of padding should be divided by pad_size_divisor
    53. size_divisor=32,
    54. # Additional paddings for pixel scale
    55. extra_pad_ratio=0.5)
    56. # -----model related-----
    57. # The scaling factor that controls the depth of the network structure
    58. deepen_factor = 0.33
    59. # The scaling factor that controls the width of the network structure
    60. widen_factor = 0.5
    61. # Strides of multi-scale prior box
    62. strides = [8, 16, 32]
    63. num_det_layers = 3 # The number of model output scales
    64. norm_cfg = dict(type='BN', momentum=0.03, eps=0.001) # Normalization config
    65. # -----train val related-----
    66. affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
    67. loss_cls_weight = 0.5
    68. loss_bbox_weight = 0.05
    69. loss_obj_weight = 1.0
    70. prior_match_thr = 4. # Priori box matching threshold
    71. # The obj loss weights of the three output layers
    72. obj_level_weights = [4., 1., 0.4]
    73. lr_factor = 0.01 # Learning rate scaling factor
    74. weight_decay = 0.0005
    75. # Save model checkpoint and validation intervals
    76. save_checkpoint_intervals = 10
    77. # The maximum checkpoints to keep.
    78. max_keep_ckpts = 3
    79. # Single-scale training is recommended to
    80. # be turned on, which can speed up training.
    81. env_cfg = dict(cudnn_benchmark=True)
    82. # ===============================Unmodified in most cases====================
    83. model = dict(
    84. type='YOLODetector',
    85. data_preprocessor=dict(
    86. type='mmdet.DetDataPreprocessor',
    87. mean=[0., 0., 0.],
    88. std=[255., 255., 255.],
    89. bgr_to_rgb=True),
    90. backbone=dict(
    91. ##修改部分
    92. plugins=[
    93. dict(cfg=dict(type='SKAttention'),
    94. stages=(False, False, False, True))
    95. ],
    96. type='YOLOv5CSPDarknet',
    97. deepen_factor=deepen_factor,
    98. widen_factor=widen_factor,
    99. norm_cfg=norm_cfg,
    100. act_cfg=dict(type='SiLU', inplace=True)
    101. ),
    102. neck=dict(
    103. type='YOLOv5PAFPN',
    104. deepen_factor=deepen_factor,
    105. widen_factor=widen_factor,
    106. in_channels=[256, 512, 1024],
    107. out_channels=[256, 512, 1024],
    108. num_csp_blocks=3,
    109. norm_cfg=norm_cfg,
    110. act_cfg=dict(type='SiLU', inplace=True)),
    111. bbox_head=dict(
    112. type='YOLOv5Head',
    113. head_module=dict(
    114. type='YOLOv5HeadModule',
    115. num_classes=num_classes,
    116. in_channels=[256, 512, 1024],
    117. widen_factor=widen_factor,
    118. featmap_strides=strides,
    119. num_base_priors=3),
    120. prior_generator=dict(
    121. type='mmdet.YOLOAnchorGenerator',
    122. base_sizes=anchors,
    123. strides=strides),
    124. # scaled based on number of detection layers
    125. loss_cls=dict(
    126. type='mmdet.CrossEntropyLoss',
    127. use_sigmoid=True,
    128. reduction='mean',
    129. loss_weight=loss_cls_weight *
    130. (num_classes / 80 * 3 / num_det_layers)),
    131. loss_bbox=dict(
    132. type='IoULoss',
    133. iou_mode='ciou',
    134. bbox_format='xywh',
    135. eps=1e-7,
    136. reduction='mean',
    137. loss_weight=loss_bbox_weight * (3 / num_det_layers),
    138. return_iou=True),
    139. loss_obj=dict(
    140. type='mmdet.CrossEntropyLoss',
    141. use_sigmoid=True,
    142. reduction='mean',
    143. loss_weight=loss_obj_weight *
    144. ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
    145. prior_match_thr=prior_match_thr,
    146. obj_level_weights=obj_level_weights),
    147. test_cfg=model_test_cfg)
    148. albu_train_transforms = [
    149. dict(type='Blur', p=0.01),
    150. dict(type='MedianBlur', p=0.01),
    151. dict(type='ToGray', p=0.01),
    152. dict(type='CLAHE', p=0.01)
    153. ]
    154. pre_transform = [
    155. dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    156. dict(type='LoadAnnotations', with_bbox=True)
    157. ]
    158. train_pipeline = [
    159. *pre_transform,
    160. dict(
    161. type='Mosaic',
    162. img_scale=img_scale,
    163. pad_val=114.0,
    164. pre_transform=pre_transform),
    165. dict(
    166. type='YOLOv5RandomAffine',
    167. max_rotate_degree=0.0,
    168. max_shear_degree=0.0,
    169. scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
    170. # img_scale is (width, height)
    171. border=(-img_scale[0] // 2, -img_scale[1] // 2),
    172. border_val=(114, 114, 114)),
    173. dict(
    174. type='mmdet.Albu',
    175. transforms=albu_train_transforms,
    176. bbox_params=dict(
    177. type='BboxParams',
    178. format='pascal_voc',
    179. label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
    180. keymap={
    181. 'img': 'image',
    182. 'gt_bboxes': 'bboxes'
    183. }),
    184. dict(type='YOLOv5HSVRandomAug'),
    185. dict(type='mmdet.RandomFlip', prob=0.5),
    186. dict(
    187. type='mmdet.PackDetInputs',
    188. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
    189. 'flip_direction'))
    190. ]
    191. train_dataloader = dict(
    192. batch_size=train_batch_size_per_gpu,
    193. num_workers=train_num_workers,
    194. persistent_workers=persistent_workers,
    195. pin_memory=True,
    196. sampler=dict(type='DefaultSampler', shuffle=True),
    197. dataset=dict(
    198. type=dataset_type,
    199. data_root=data_root,
    200. ann_file=train_ann_file,
    201. data_prefix=dict(img=train_data_prefix),
    202. filter_cfg=dict(filter_empty_gt=False, min_size=32),
    203. pipeline=train_pipeline))
    204. test_pipeline = [
    205. dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    206. dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    207. dict(
    208. type='LetterResize',
    209. scale=img_scale,
    210. allow_scale_up=False,
    211. pad_val=dict(img=114)),
    212. dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    213. dict(
    214. type='mmdet.PackDetInputs',
    215. meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
    216. 'scale_factor', 'pad_param'))
    217. ]
    218. val_dataloader = dict(
    219. batch_size=val_batch_size_per_gpu,
    220. num_workers=val_num_workers,
    221. persistent_workers=persistent_workers,
    222. pin_memory=True,
    223. drop_last=False,
    224. sampler=dict(type='DefaultSampler', shuffle=False),
    225. dataset=dict(
    226. type=dataset_type,
    227. data_root=data_root,
    228. test_mode=True,
    229. data_prefix=dict(img=val_data_prefix),
    230. ann_file=val_ann_file,
    231. pipeline=test_pipeline,
    232. batch_shapes_cfg=batch_shapes_cfg))
    233. test_dataloader = val_dataloader
    234. param_scheduler = None
    235. optim_wrapper = dict(
    236. type='OptimWrapper',
    237. optimizer=dict(
    238. type='SGD',
    239. lr=base_lr,
    240. momentum=0.937,
    241. weight_decay=weight_decay,
    242. nesterov=True,
    243. batch_size_per_gpu=train_batch_size_per_gpu),
    244. constructor='YOLOv5OptimizerConstructor')
    245. default_hooks = dict(
    246. param_scheduler=dict(
    247. type='YOLOv5ParamSchedulerHook',
    248. scheduler_type='linear',
    249. lr_factor=lr_factor,
    250. max_epochs=max_epochs),
    251. checkpoint=dict(
    252. type='CheckpointHook',
    253. interval=save_checkpoint_intervals,
    254. save_best='auto',
    255. max_keep_ckpts=max_keep_ckpts))
    256. custom_hooks = [
    257. dict(
    258. type='EMAHook',
    259. ema_type='ExpMomentumEMA',
    260. momentum=0.0001,
    261. update_buffers=True,
    262. strict_load=False,
    263. priority=49)
    264. ]
    265. val_evaluator = dict(
    266. type='mmdet.CocoMetric',
    267. proposal_nums=(100, 1, 10),
    268. ann_file=data_root + val_ann_file,
    269. metric='bbox')
    270. test_evaluator = val_evaluator
    271. train_cfg = dict(
    272. type='EpochBasedTrainLoop',
    273. max_epochs=max_epochs,
    274. val_interval=save_checkpoint_intervals)
    275. val_cfg = dict(type='ValLoop')
    276. test_cfg = dict(type='TestLoop')

  • 相关阅读:
    点云数据转pnts二进制数据
    靶机: medium_socnet
    Java实现图片上传功能(前后端:vue+springBoot)
    java计算机毕业设计影院资源管理系统演示录像2020源程序+mysql+系统+lw文档+远程调试
    React Native 搭建开发环境和创建新项目并运行的详细教程
    季节性壁炉布置:让您的家温馨如冬季仙境
    评价——秩和比综合评价
    国产操作系统生态建设,小程序技术来帮忙
    Cilium系列-9-主机路由切换为基于 BPF 的模式
    Redisson集成SpringBoot
  • 原文地址:https://blog.csdn.net/Vlone_pp/article/details/137525076