• YOLOv8 Latest Improvement Series: Adding the ContextAggregation Attention Module for an Effective Accuracy Gain!


    YOLOv8 improvement: add the ContextAggregation attention module

    1. Modify the yaml file
    2. Create ContextAggregation.py
    3. Modify tasks.py
    For the detailed modification walkthrough, follow the Bilibili blogger AI学术叫叫兽.

    The full source code is available on the Bilibili channel AI学术叫叫兽.
    Get your research going!
    The paper is linked here.

    1. Modify the yaml file

    The fully modified yaml file is shown below; ContextAggregation is inserted at two places in the head.

    # Ultralytics YOLO 🚀, AGPL-3.0 license
    # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
    
    # Parameters
    nc: 80  # number of classes
    scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
      # [depth, width, max_channels]
      n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
      s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
      m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
      l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
      x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
    
    # YOLOv8.0n backbone
    backbone:
      # [from, repeats, module, args]
      - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
      - [-1, 1, GhostConv, [128, 3, 2]]  # 1-P2/4
      - [-1, 3, C2f, [128, True]]
      - [-1, 1, GhostConv, [256, 3, 2]]  # 3-P3/8
      - [-1, 6, C2f, [256, True]]
      - [-1, 1, GhostConv, [512, 3, 2]]  # 5-P4/16
      - [-1, 6, C2f, [512, True]]
      - [-1, 1, GhostConv, [1024, 3, 2]]  # 7-P5/32
      - [-1, 3, C2f, [1024, True]]
      - [-1, 1, SPPF, [1024, 5]]  # 9
    # YOLOv8.0n head
    head:
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
      - [-1, 3, C2f, [512]]  # 12
    
      - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
      - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
      - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
    
      - [-1, 1, Conv, [256, 3, 2]]
      - [[-1, 12], 1, Concat, [1]]  # cat head P4
      - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
      - [-1, 1, ContextAggregation, [512]]  # 19
      - [-1, 1, Conv, [512, 3, 2]]
      - [[-1, 9], 1, Concat, [1]]  # cat head P5
      - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
      - [-1, 1, ContextAggregation, [1024]]  # 23
      - [[15, 19, 23], 1, Detect, [nc]]  # Detect(P3, P4, P5); P4/P5 are taken after ContextAggregation
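
    To confirm that the modified configuration parses before training, a minimal check along the lines below can be used. It assumes the yaml above is saved as yolov8n-ContextAggregation.yaml (a hypothetical file name; the "n" in the name selects the n scale) and that the module file and tasks.py changes from the next two steps are already in place.

    # Minimal parse check (hypothetical file name "yolov8n-ContextAggregation.yaml")
    from ultralytics import YOLO

    model = YOLO("yolov8n-ContextAggregation.yaml")  # builds the model and logs the per-layer table
    model.info()  # prints a summary: layer count, parameters, GFLOPs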
    

    2. Create ContextAggregation.py

    import torch
    import torch.nn as nn
    # ConvModule and the weight-init helpers are provided by mmcv
    from mmcv.cnn import ConvModule, caffe2_xavier_init, constant_init


    class ContextAggregation(nn.Module):
        """Context aggregation block: computes a global, attention-weighted context
        vector and adds it back onto the input feature map as a gated residual."""
     
        def __init__(self, in_channels, reduction=1, conv_cfg=None):
            super(ContextAggregation, self).__init__()
            self.in_channels = in_channels
            self.reduction = reduction
            self.inter_channels = max(in_channels // reduction, 1)
     
            conv_params = dict(kernel_size=1, conv_cfg=conv_cfg, act_cfg=None)
     
            self.a = ConvModule(in_channels, 1, **conv_params)
            self.k = ConvModule(in_channels, 1, **conv_params)
            self.v = ConvModule(in_channels, self.inter_channels, **conv_params)
            self.m = ConvModule(self.inter_channels, in_channels, **conv_params)
     
            self.init_weights()
     
        def init_weights(self):
            for m in (self.a, self.k, self.v):
                caffe2_xavier_init(m.conv)
            constant_init(self.m.conv, 0)
     
        def forward(self, x):
            n, c = x.size(0), self.inter_channels
     
            # a: [N, 1, H, W]
            a = self.a(x).sigmoid()
     
            # k: [N, 1, HW, 1]
            k = self.k(x).view(n, 1, -1, 1).softmax(2)
     
            # v: [N, 1, C, HW]
            v = self.v(x).view(n, 1, c, -1)
     
            # y: [N, C, 1, 1]
            y = torch.matmul(v, k).view(n, c, 1, 1)
            y = self.m(y) * a
     
            return x + y
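
    The block depends on mmcv for ConvModule and the init helpers (with mmcv 2.x and mmengine, these init helpers are imported from mmengine.model instead of mmcv.cnn). A quick standalone sanity check, run by executing this file directly, could look like the sketch below; it only verifies that the module preserves the input shape.

    # Standalone sanity check for ContextAggregation (append to ContextAggregation.py)
    if __name__ == "__main__":
        x = torch.randn(2, 64, 32, 32)           # dummy feature map: N=2, C=64, 32x32
        block = ContextAggregation(64, reduction=1)
        y = block(x)
        print(y.shape)                           # expected: torch.Size([2, 64, 32, 32])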
    

    3. Modify tasks.py

    Find the following code in tasks.py (at roughly line 650) and replace it with the version below, which adds ContextAggregation to the tuple of modules whose output channels are scaled:

    if m in (Classify, Conv, GGhostRegNet, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv,
             Focus, BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,
             SEAttention, ContextAggregation):
        c1, c2 = ch[f], args[0]
        if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
            c2 = make_divisible(min(c2, max_channels) * width, 8)
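
    Besides extending this tuple, tasks.py must also be able to resolve the ContextAggregation name, so an import is needed near the top of the file. A minimal sketch, assuming ContextAggregation.py was created under ultralytics/nn/ (adjust the path to wherever the file from step two actually lives):

    # Hypothetical import path; change it to match where ContextAggregation.py was saved
    from ultralytics.nn.ContextAggregation import ContextAggregation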
    

    Check that the modifications work

    Once everything runs, it looks like this (screenshot in the original post).
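
    For a quick end-to-end check, a short training run can be started as sketched below; the config name, dataset yaml, and hyperparameters are placeholders rather than values from the original post.

    from ultralytics import YOLO

    # Hypothetical config name from step one; coco128.yaml is the small sample
    # dataset bundled with ultralytics, used here only as a placeholder.
    model = YOLO("yolov8n-ContextAggregation.yaml")
    model.train(data="coco128.yaml", epochs=10, imgsz=640)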

    Note!

    Don't forget to follow the Bilibili blogger AI学术叫叫兽.
    Earlier Bilibili videos already covered adding a fourth detection layer. What happens if you combine this attention module with four detection heads? Go try it!

    Get your research going!

    Improvements covering attention modules, feature-extraction backbones, extra detection heads, and optimized convolution operations have already been published.
    More improvements are on the way, and at the request of Bilibili followers a column on academic paper writing will be launched on Bilibili soon.

  • Original article: https://blog.csdn.net/weixin_51692073/article/details/132621184