• 【yolo系列:yolov7改进wise-iou】


    yolo系列文章目录

    学习视频:
    YOLOV7改进-Wise IoU_哔哩哔哩_bilibili

    代码地址:
    objectdetection_script/yolov7-iou.py at master · z1069614715/objectdetection_script (github.com)
    Wise-IoU(WIoU)是一种用于目标检测的创新性损失函数,针对传统边界框损失函数中对训练数据质量要求较高的问题进行了改进。在目标检测中,边界框损失函数的设计对模型性能至关重要。以往的研究大多假定训练数据是高质量的,并试图通过强化边界框损失的拟合能力来提高模型性能。然而,在实际训练集中,通常包含了一些低质量的示例,如果盲目地加强对这些低质量示例的回归,可能会损害模型的检测性能。

    为了解决这个问题,先前的研究提出了Focal-EIoU v1方法,但其聚焦机制是静态的,未充分挖掘非单调聚焦机制的潜力。基于这一观点,研究者们提出了一种新的动态非单调聚焦机制,即Wise-IoU(WIoU)。这种机制使用“离群度”替代传统的IoU(Intersection over Union)来评估锚框的质量,并引入了明智的梯度增益分配策略。这一策略在降低高质量锚框竞争性的同时,也减小了低质量示例产生的有害梯度。这使得WIoU能够更集中地处理普通质量的锚框,从而提高整体检测器的性能。

    在实际应用中,将WIoU应用于最先进的单级检测器YOLOv7时,它在MS-COCO数据集上的AP-75(Average Precision with IoU threshold at 0.75)从53.03%提升到了54.50%。这种显著的性能提升表明Wise-IoU在处理目标检测任务中具有很高的实用性和效果。通过引入动态非单调聚焦机制,WIoU为目标检测领域带来了新的思路和方法,为提高模型的鲁棒性和准确性提供了有力支持。



    一、在yolov7之上进行替换

    utils/general.py替换bbiox

    class WIoU_Scale:
        ''' monotonous: {
                None: origin v1
                True: monotonic FM v2
                False: non-monotonic FM v3
            }
            momentum: The momentum of running mean'''
        
        iou_mean = 1.
        monotonous = False
        _momentum = 1 - 0.5 ** (1 / 7000)
        _is_train = True
    
        def __init__(self, iou):
            self.iou = iou
            self._update(self)
        
        @classmethod
        def _update(cls, self):
            if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                                             cls._momentum * self.iou.detach().mean().item()
        
        @classmethod
        def _scaled_loss(cls, self, gamma=1.9, delta=3):
            if isinstance(self.monotonous, bool):
                if self.monotonous:
                    return (self.iou.detach() / self.iou_mean).sqrt()
                else:
                    beta = self.iou.detach() / self.iou_mean
                    alpha = delta * torch.pow(gamma, beta - delta)
                    return beta / alpha
            return 1
        
    
    def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
        # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
    
        # Get the coordinates of bounding boxes
        if xywh:  # transform from xywh to xyxy
            (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
            w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
            b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
            b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
        else:  # x1, y1, x2, y2 = box1
            b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
            b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
            w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
            w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
    
        # Intersection area
        inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
                (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
    
        # Union Area
        union = w1 * h1 + w2 * h2 - inter + eps
        if scale:
            self = WIoU_Scale(1 - (inter / union))
    
        # IoU
        # iou = inter / union # ori iou
        iou = torch.pow(inter/(union + eps), alpha) # alpha iou
        if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
            cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
            ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
            if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
                c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
                rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
                if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                    v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                    with torch.no_grad():
                        alpha_ciou = v / (v - iou + (1 + eps))
                    if Focal:
                        return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma)  # Focal_CIoU
                    else:
                        return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
                elif EIoU:
                    rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                    rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                    cw2 = torch.pow(cw ** 2 + eps, alpha)
                    ch2 = torch.pow(ch ** 2 + eps, alpha)
                    if Focal:
                        return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
                    else:
                        return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
                elif SIoU:
                    # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                    s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                    s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                    sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                    sin_alpha_1 = torch.abs(s_cw) / sigma
                    sin_alpha_2 = torch.abs(s_ch) / sigma
                    threshold = pow(2, 0.5) / 2
                    sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                    angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                    rho_x = (s_cw / cw) ** 2
                    rho_y = (s_ch / ch) ** 2
                    gamma = angle_cost - 2
                    distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                    omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                    omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                    shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                    if Focal:
                        return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
                    else:
                        return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
                elif WIoU:
                    if Focal:
                        raise Exception("WIoU do not support Focal.")
                    elif scale:
                        return getattr(WIoU_Scale, '_scaled_loss')(self), (1 - iou) * torch.exp((rho2 / c2)), iou # WIoU https://arxiv.org/abs/2301.10051
                    else:
                        return iou, torch.exp((rho2 / c2)) # WIoU v1
                if Focal:
                    return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma)  # Focal_DIoU
                else:
                    return iou - rho2 / c2  # DIoU
            c_area = cw * ch + eps  # convex area
            if Focal:
                return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
            else:
                return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
        if Focal:
            return iou, torch.pow(inter/(union + eps), gamma)  # Focal_IoU
        else:
            return iou  # IoU
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125

    二、在loss.py的ComputeLoss,ComputeLossOTA修改如下

    在这里插入图片描述

    if type(iou) is tuple:
        if len(iou) == 2:
            lbox += (iou[1].detach() * (1 - iou[0].)).mean()
            iou = iou[0]
        else:
            lbox += (iou[0] * iou[1]).mean()
            iou = iou[-1]
    else:
        lbox += (1.0 - iou.squeeze()).mean()  # iou loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    修改在这里插入图片描述

    在这里插入图片描述

    三、设置版本

    monotonous = False
    就是v3,truev2,none为v1版本,可以自行尝试效果。
    在目标检测领域,选择合适的模型版本对性能提升至关重要。在这个实验中,我们探讨了三个不同版本的模型:v1、v2、v3,它们分别代表了monotonous参数设置为None、True和False的情况。这个参数的变化引入了不同的聚焦机制,从而影响了模型的性能表现。

    首先,当monotonous参数为None(v1版本)时,模型的聚焦机制是静态的,无法适应数据中的复杂特征变化。这可能导致模型无法有效地捕捉到目标边界的微小变化,从而影响了检测的准确性。

    其次,monotonous参数为True(v2版本)时,模型采用了单调递增的聚焦机制。这种机制对于某些场景可能更加适用,但在某些情况下可能会忽略掉一些关键的目标特征,导致性能提升有限。

    最后,monotonous参数为False(v3版本)时,引入了非单调递增的聚焦机制。这种机制允许模型更加灵活地适应各种特定目标的形状和结构,从而在处理复杂场景时表现更为出色。

    在实验过程中,我们可以根据不同版本的模型输出结果进行性能对比分析。通过比较各个版本在各种测试场景下的检测准确性、鲁棒性和处理速度,我们可以确定哪个版本在实验中取得了明显的提升。

    总结而言,选择适当的聚焦机制非常关键,它直接影响了目标检测模型的性能。在实验中,我们可以通过调整monotonous参数来尝试不同版本的模型,从而找到最适合特定任务和数据集的版本。这种灵活性使得我们能够根据实际需求优化模型,取得更好的检测结果。在选择模型版本时,综合考虑准确性、鲁棒性和处理速度等因素,可以帮助我们做出明智的决策,提高目标检测系统的性能和可靠性。


    总结

    确定好训练配置后,即可进行性能对比分析,找出哪个版本在实验中取得了明显的提升。

  • 相关阅读:
    c# sqlite 修改字段类型
    Java基础
    谣言检查论文精读——5.SpotFake: A Multi-modal Framework for Fake News Detection
    JVM类装载器详解
    css:img引入svg后修改颜色
    软件工程师和程序员到底有多大的区别?
    基于Java+SpringBoot+Vue前后端分离青年公寓服务平台设计和实现
    MySQL事务隔离与行锁的关系
    【C++】单例模式
    【老生谈算法】matlab实现自适应对消器的LMS算法——LMS算法
  • 原文地址:https://blog.csdn.net/weixin_47869094/article/details/133654505