• yolo增加slide loss,改善样本不平衡问题


    slide loss的主要作用是让模型更加关注难例,可以轻微的改善模型在难例检测上的效果

    论文地址:https://arxiv.org/pdf/2208.02019.pdf

    代码:GitHub - Krasjet-Yu/YOLO-FaceV2: YOLO-FaceV2: A Scale and Occlusion Aware Face Detector

            样本不平衡问题,即在大多数情况下,容易样本的数量很大,而困难样本相对稀疏,引起了很多关注。在本文的工作中,设计了一个看起来像“slide”的Slide Loss函数来解决这个问题。简单样本和困难样本之间的区别是基于预测框和ground truth 框的IoU大小。为了减少超参数,将所有边界框的 IoU 值的平均值作为阈值 µ,小于µ的取负样本,大于µ的取正样本。

            然而,由于分类不明确,边界附近的样本往往会遭受较大的损失。希望模型能够学习优化这些样本,并更充分地使用这些样本来训练网络。然而,此类样本的数量相对较少。因此,尝试为困难样本分配更高的权重。首先通过参数μ将样本分为正样本和负样本。然后,通过加权函数Slide对边界处的样本进行强调,如图 4 所示。Slide加权函数可以表示为公式5。

    在utils/loss.py增加

    1. import math
    2. class SlideLoss(nn.Module):
    3. def __init__(self, loss_fcn):
    4. super(SlideLoss, self).__init__()
    5. self.loss_fcn = loss_fcn
    6. self.reduction = loss_fcn.reduction
    7. self.loss_fcn.reduction = 'none' # required to apply SL to each element
    8. def forward(self, pred, true, auto_iou=0.5):
    9. loss = self.loss_fcn(pred, true)
    10. if auto_iou < 0.2:
    11. auto_iou = 0.2
    12. b1 = true <= auto_iou - 0.1
    13. a1 = 1.0
    14. b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)
    15. a2 = math.exp(1.0 - auto_iou)
    16. b3 = true >= auto_iou
    17. a3 = torch.exp(-(true - 1.0))
    18. modulating_weight = a1 * b1 + a2 * b2 + a3 * b3
    19. loss *= modulating_weight
    20. if self.reduction == 'mean':
    21. return loss.mean()
    22. elif self.reduction == 'sum':
    23. return loss.sum()
    24. else: # 'none'
    25. return loss

    在data\hyps\hyp.scratch-low.yaml中增加

    slide_ratio: 1 # >=1启用slide loss, <1关闭

    在utils/loss.py的ComputeLoss函数中做如下修改:

    1. class ComputeLoss:
    2. # Compute losses
    3. def __init__(self, model, autobalance=False):
    4. super(ComputeLoss, self).__init__()
    5. device = next(model.parameters()).device # get model device
    6. h = model.hyp # hyperparameters
    7. # Define criteria
    8. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
    9. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
    10. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    11. self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
    12. # slide loss
    13. self.slide_ratio = h['slide_ratio']
    14. if self.slide_ratio > 0:
    15. BCEcls, BCEobj = SlideLoss(BCEcls), SlideLoss(BCEobj)
    16. # Focal loss
    17. g = h['fl_gamma'] # focal loss gamma
    18. if g > 0:
    19. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
    20. det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
    21. self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7
    22. self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
    23. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
    24. for k in 'na', 'nc', 'nl', 'anchors':
    25. setattr(self, k, getattr(det, k))
    26. def __call__(self, p, targets): # predictions, targets, model
    27. device = targets.device
    28. lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
    29. lrepBox, lrepGT = torch.zeros(1, device=device), torch.zeros(1, device=device)
    30. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
    31. # Losses
    32. for i, pi in enumerate(p): # layer index, layer predictions
    33. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
    34. tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
    35. n = b.shape[0] # number of targets
    36. if n:
    37. ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
    38. # Regression
    39. pxy = ps[:, :2].sigmoid() * 2. - 0.5
    40. pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
    41. pbox = torch.cat((pxy, pwh), 1) # predicted box
    42. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
    43. auto_iou = iou.mean()
    44. lbox += (1.0 - iou).mean() # iou loss
    45. # Objectness
    46. tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio
    47. # Classification
    48. if self.nc > 1: # cls loss (only if multiple classes)
    49. t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
    50. t[range(n), tcls[i]] = self.cp
    51. if self.slide_ratio > 0:
    52. lcls += self.BCEcls(ps[:, 5:], t, auto_iou) # BCE
    53. else:
    54. lcls += self.BCEcls(ps[:, 5:], t) # BCE
    55. # Append targets to text file
    56. # with open('targets.txt', 'a') as file:
    57. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
    58. if self.slide_ratio > 0 and n:
    59. obji = self.BCEobj(pi[..., 4], tobj, auto_iou)
    60. else:
    61. obji = self.BCEobj(pi[..., 4], tobj)
    62. lobj += obji * self.balance[i] # obj loss
    63. if self.autobalance:
    64. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
    65. if self.autobalance:
    66. self.balance = [x / self.balance[self.ssi] for x in self.balance]
    67. lbox *= self.hyp['box']
    68. lobj *= self.hyp['obj']
    69. lcls *= self.hyp['cls']
    70. bs = tobj.shape[0] # batch size
    71. loss = lbox + lobj + lcls
    72. return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

     主要修改如下:

    1、__init__中增加

    1. # slide loss
    2. self.slide_ratio = h['slide_ratio']
    3. if self.slide_ratio > 0:
    4. BCEcls, BCEobj = SlideLoss(BCEcls), SlideLoss(BCEobj)

    2、计算完iou后增加

    auto_iou = iou.mean()

    3、在类别损失函数上

    1. if self.slide_ratio > 0:
    2. lcls += self.BCEcls(ps[:, 5:], t, auto_iou) # BCE
    3. else:
    4. lcls += self.BCEcls(ps[:, 5:], t) # BCE

    4、前背景损失函数上

    1. if self.slide_ratio > 0 and n:
    2. obji = self.BCEobj(pi[..., 4], tobj, auto_iou)
    3. else:
    4. obji = self.BCEobj(pi[..., 4], tobj)
  • 相关阅读:
    JVM内存模型介绍
    C#Winform自定义信息提示框控件
    java pdf转word 支持图片转换到word(最大程度的解决原PDF)
    MySQL 锁分类和详细介绍
    selenium 文件上传方法
    上班干,下班学!这份 Java 面试八股文涵盖 20 多个技术点,还有优质面经分享,别再说卷不过别人了~
    Q41F-40C手动球阀型号解析
    深入理解注意力机制(下)——缩放点积注意力及示例
    双侧检验Two-Tailed Test
    计算机毕业设计Java健身俱乐部业务关系系统(源码+系统+mysql数据库+lw文档)
  • 原文地址:https://blog.csdn.net/athrunsunny/article/details/133222379