• yolov5 优化系列(三):修改损失函数


    1.使用 Focal loss

    在util/loss.py中,computeloss类用于计算损失函数

    # Focal loss
            g = h['fl_gamma']  # focal loss gamma
            if g > 0:
                BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
    
    • 1
    • 2
    • 3
    • 4

    其中这一段就是开启Focal loss的关键!!!

    parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
     
    
    • 1
    • 2

    使用的data/hyps/hyp.scratch-low.yaml为参数配置文件,进去修改fl_gamma即可

    在这里插入图片描述

    fl_gamma实际上就是公式中红色椭圆的部分
    看看代码更易于理解:

     def forward(self, pred, true):
            loss = self.loss_fcn(pred, true)
            # p_t = torch.exp(-loss)
            # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
    
            # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
            pred_prob = torch.sigmoid(pred)  # prob from logits
            p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
            alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
            modulating_factor = (1.0 - p_t) ** self.gamma
            loss *= alpha_factor * modulating_factor
    
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:  # 'none'
                return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    调参上的技巧
    在这里插入图片描述

    1.1 增加alpha

    focalloss其实是两个参数,一个参数就是我们前述的fl_gamma,同样的道理我们也可以增加fl_alpha来调节alpha参数
    (1)进入参数配置文件
    请添加图片描述
    增加

    fl_alpha: 0.95     # my focal loss alpha:nagetive example rate
    
    • 1

    (2)然后回到核心代码那里替换这一段

            # Focal loss
            g = h['fl_gamma']  # focal loss gamma
            if g > 0:
                a=h['fl_alpha']
                BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
    
                # ————————————————使用Varifocal Loss损失函数———————————————————————————————————
                #BCEcls, BCEobj = VFLoss(BCEcls, g,a), VFLoss(BCEobj, g,a)
                # print(BCEcls)
                # print
                # ————————————————使用Varifocal Loss损失函数———————————————————————————————————
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    Varifocal 和foacl loss二选一,另一个注释掉就行

    (2)使用Varifocal Loss

    Varifocal Loss

    在这里插入图片描述

    p输入为前景类的预测概率;q为ground-truth

    class VFLoss(nn.Module):
        def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
            super(VFLoss, self).__init__()
            # 传递 nn.BCEWithLogitsLoss() 损失函数  must be nn.BCEWithLogitsLoss()
            self.loss_fcn = loss_fcn  #
            self.gamma = gamma
            self.alpha = alpha
            self.reduction = loss_fcn.reduction
            self.loss_fcn.reduction = 'mean'  # required to apply VFL to each element
    
        def forward(self, pred, true):
    
            loss = self.loss_fcn(pred, true)
    
            pred_prob = torch.sigmoid(pred)  # prob from logits
                                                                        #p
            focal_weight = true * (true > 0.0).float() + self.alpha * (pred_prob - true).abs().pow(self.gamma) * (
                        true <= 0.0).float()
            loss *= focal_weight
    
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:
                return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    true:q,即为ground-truth
    (pred_prob - true):p,即前景类的预测概率

    直接使用代码会报这个错
    在这里插入图片描述
    后面self.loss_fcn.reduction = 'mean'修改为self.loss_fcn.reduction = 'none'就没问题了

    Focal loss和Varifocal Loss始终是不如原先的效果,可能很大一部分是参数问题

  • 相关阅读:
    centos下安装配置redis7
    LVS,Nginx,Haproxy三种负载均衡产品的对比
    企企通:数字化浪潮下,企业如何利用间接采购策略,实现降本增效?
    Introduction To AMBA 简单理解
    入选C/C++领域内容榜45名
    【linux命令讲解大全】076.pgrep命令:查找和列出符合条件的进程ID
    利用C++开发一个迷你的英文单词录入和测试小程序
    827万!朔黄铁路基于5G边缘计算的智慧牵引变电所研究项目
    HTTP的请求方式有哪些?
    【视觉SLAM】Bags of Binary Words for Fast Place Recognition in Image Sequences
  • 原文地址:https://blog.csdn.net/weixin_50862344/article/details/126474702