• YOLOv10 Improvement Tutorial | Adding an Attention Mechanism to C2f-CIB



      1. Introduction

            Paper link: https://arxiv.org/abs/2311.11587

            Code link: GitHub - CV-ZhangXin/AKConv

     YOLOv10 training, validation, and inference tutorial


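    2. Adding an attention mechanism to C2f-CIB

    2.1 Copy the code

            Open ultralytics->nn->modules->block.py, paste in the SE attention code (any other attention module can be substituted), and create the C2fCIBAttention class, as shown below: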

    # Paste into ultralytics/nn/modules/block.py; relies on the torch, nn, Conv and CIB names available in that file.
    class SE(nn.Module):
        """Squeeze-and-Excitation channel attention."""

        def __init__(self, channel, reduction=16):
            super().__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze: global average per channel
            self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // reduction, channel, bias=False),
                nn.Sigmoid(),
            )

        def forward(self, x):
            b, c, _, _ = x.size()
            y = self.avg_pool(x).view(b, c)
            y = self.fc(y).view(b, c, 1, 1)  # excitation: per-channel weights in (0, 1)
            return x * y.expand_as(x)


    class C2fCIBAttention(nn.Module):
        """C2f block built from CIB modules, with SE attention applied to its output."""

        def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
            """Initialize with ch_in, ch_out, number of CIB blocks, shortcut, large-kernel flag, groups, expansion."""
            super().__init__()
            self.c = int(c2 * e)  # hidden channels
            self.cv1 = Conv(c1, 2 * self.c, 1, 1)
            self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
            self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
            self.atten = SE(c2)  # attention over the c2 output channels

        def forward(self, x):
            """Forward pass through the C2f layer, followed by SE attention."""
            y = list(self.cv1(x).chunk(2, 1))
            y.extend(m(y[-1]) for m in self.m)
            return self.atten(self.cv2(torch.cat(y, 1)))

        def forward_split(self, x):
            """Forward pass using split() instead of chunk()."""
            y = list(self.cv1(x).split((self.c, self.c), 1))
            y.extend(m(y[-1]) for m in self.m)
            return self.atten(self.cv2(torch.cat(y, 1)))  # same attention as forward()
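            To get a feel for what the SE block adds, the sketch below counts its weights and traces the channel widths that C2fCIBAttention sets up in its constructor. It is plain arithmetic only; the function names are illustrative, and the value 256 assumes the n scale (width 0.25 applied to the 1024 in the YAML).

```python
def se_param_count(channel, reduction=16):
    """Weights in the two bias-free Linear layers of the SE block."""
    hidden = channel // reduction
    return channel * hidden + hidden * channel


def c2f_cib_channels(c2, n=1, e=0.5):
    """Channel widths created in C2fCIBAttention.__init__ (mirrors the code above)."""
    c = int(c2 * e)  # hidden channels per branch
    return {"hidden": c, "cv1_out": 2 * c, "cv2_in": (2 + n) * c, "se": c2}


print(se_param_count(256))    # 8192
print(c2f_cib_channels(256))  # {'hidden': 128, 'cv1_out': 256, 'cv2_in': 384, 'se': 256}
```

            Note that channel // reduction becomes 0 when channel is smaller than reduction, so a smaller reduction value would be needed for very narrow layers.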

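            After pasting the two classes, add C2fCIBAttention to the declaration at the top of block.py (the __all__ tuple).

            Then declare C2fCIBAttention in ultralytics->nn->modules->__init__.py as well.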
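            A quick way to confirm both declarations took effect is to check whether the name is reachable from each module. The helper below is a hypothetical convenience written for this tutorial, not part of Ultralytics; it returns None when the package cannot be imported at all.

```python
import importlib


def is_exported(module_name, symbol):
    """Return True/False if the module imports, or None if it is unavailable."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, symbol)


# Both lines should report True once the class has been added and declared.
for name in ("ultralytics.nn.modules.block", "ultralytics.nn.modules"):
    print(name, is_exported(name, "C2fCIBAttention"))
```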

    2.2 Modify tasks.py

            Open ultralytics->nn->tasks.py and make the changes shown in the figure.
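            The figure is not reproduced here. In Ultralytics-style code this step usually means importing the new class in tasks.py and adding it to the module sets that parse_model checks; that is an assumption about the missing figure. The sketch below is a simplified stand-in for that logic (it is not the real parse_model, and the two sets are hypothetical) showing why registration matters: only registered modules get c1 prepended, c2 scaled by width, and the repeat count inserted.

```python
import math


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


# Hypothetical stand-ins for the module sets checked inside parse_model.
CHANNEL_AWARE = {"Conv", "C2f", "C2fCIB", "C2fCIBAttention"}
REPEAT_AWARE = {"C2f", "C2fCIB", "C2fCIBAttention"}


def resolve_args(name, c1, n, args, depth=0.33, width=0.25, max_channels=1024):
    """Turn one YAML row into constructor arguments (simplified; the nc special case is omitted)."""
    n = max(round(n * depth), 1) if n > 1 else n  # depth scaling of repeats
    if name not in CHANNEL_AWARE:
        return list(args), n  # args are passed through untouched
    c2 = make_divisible(min(args[0], max_channels) * width, 8)
    new_args = [c1, c2, *args[1:]]
    if name in REPEAT_AWARE:
        new_args.insert(2, n)  # the block handles its own repeats
        n = 1
    return new_args, n


# c1 = 384 assumes the n scale: 128 channels from layer 21 plus 256 from layer 10.
print(resolve_args("C2fCIBAttention", 384, 3, [1024, True, True]))
print(resolve_args("Unregistered", 384, 3, [1024, True, True]))
```

            The first call yields ([384, 256, 1, True, True], 1), which maps to C2fCIBAttention(c1=384, c2=256, n=1, shortcut=True, lk=True). The second returns the raw YAML arguments, which the constructor would misread as c1=1024.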

    2.3 Modify yolov10n.yaml

            In yolov10n.yaml, replace C2fCIB with C2fCIBAttention.

    # Ultralytics YOLO 🚀, AGPL-3.0 license
    # YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect

    # Parameters
    nc: 80 # number of classes
    scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
      # [depth, width, max_channels]
      n: [0.33, 0.25, 1024]

    backbone:
      # [from, repeats, module, args]
      - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
      - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
      - [-1, 3, C2f, [128, True]]
      - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
      - [-1, 6, C2f, [256, True]]
      - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
      - [-1, 6, C2f, [512, True]]
      - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
      - [-1, 3, C2f, [1024, True]]
      - [-1, 1, SPPF, [1024, 5]] # 9
      - [-1, 1, PSA, [1024]] # 10

    # YOLOv8.0n head
    head:
      - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
      - [[-1, 6], 1, Concat, [1]] # cat backbone P4
      - [-1, 3, C2f, [512]] # 13

      - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
      - [[-1, 4], 1, Concat, [1]] # cat backbone P3
      - [-1, 3, C2f, [256]] # 16 (P3/8-small)

      - [-1, 1, Conv, [256, 3, 2]]
      - [[-1, 13], 1, Concat, [1]] # cat head P4
      - [-1, 3, C2f, [512]] # 19 (P4/16-medium)

      - [-1, 1, SCDown, [512, 3, 2]]
      - [[-1, 10], 1, Concat, [1]] # cat head P5
      - [-1, 3, C2fCIBAttention, [1024, True, True]] # 22 (P5/32-large)

      - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
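            If the same substitution has to be applied to several config files, it can be scripted. The sketch below is one possible approach using a word-boundary regex so that names already changed to C2fCIBAttention are left alone; the function name is illustrative, and reading or writing the actual file is left out.

```python
import re


def use_attention_block(yaml_text):
    """Replace the module name C2fCIB with C2fCIBAttention, matching whole words only."""
    return re.sub(r"\bC2fCIB\b", "C2fCIBAttention", yaml_text)


row = "- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)"
patched = use_attention_block(row)
print(patched)
print(use_attention_block(patched) == patched)  # running it twice changes nothing
```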


    2.4 Modify train.py

            Set the path to yolov10n.yaml in the train.py script, then run the script to start training.
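            The original train.py is not shown, so the following is only a minimal sketch of what such a script might look like. It assumes the standard ultralytics YOLO class can load the YOLOv10 config; the config path, the coco8.yaml dataset, and all hyperparameter values are placeholders to replace with your own.

```python
def build_train_args(data="coco8.yaml", epochs=100, imgsz=640, batch=16):
    """Collect training settings; every default here is a placeholder."""
    return {"data": data, "epochs": epochs, "imgsz": imgsz, "batch": batch}


def train(cfg_path, **overrides):
    """Build the model from the modified YAML and start training."""
    from ultralytics import YOLO  # imported lazily; needs the modified package installed

    model = YOLO(cfg_path)  # builds the network from the YAML, including C2fCIBAttention
    return model.train(**build_train_args(**overrides))


# Example call (path is an assumption about where the YAML lives):
# train("ultralytics/cfg/models/v10/yolov10n.yaml", epochs=50)
```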


  • Original article: https://blog.csdn.net/StopAndGoyyy/article/details/140110212