• YOLOv5代码解读[03] utils/loss.py文件解析


    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from utils.metrics import bbox_iou
    from utils.torch_utils import de_parallel
    from utils.general import xywh2xyxy, box_iou
    
    
    # 标签平滑
    def smooth_BCE(eps=0.1): 
        return 1.0 - 0.5*eps, 0.5*eps
    
    
    # 主要是为了减少false negatives(认为标签中没有,但实际中有)的影响。COCO数据集中有些目标的标签缺失。
    class BCEBlurWithLogitsLoss(nn.Module):
        def __init__(self, alpha=0.05):
            super().__init__()
            self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  
            self.alpha = alpha
    
        def forward(self, pred, true):
            loss = self.loss_fcn(pred, true)
            pred = torch.sigmoid(pred)  # prob from logits
            dx = pred - true  # reduce only missing label effects
            # dx = (pred - true).abs()  # reduce missing label and false label effects
            alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
            loss *= alpha_factor
            return loss.mean()
    
    # Focal loss主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。
    # 该损失函数降低了大量简单负样本在训练中所占的权重。
    class FocalLoss(nn.Module):
        # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
        def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
            super().__init__()
            self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
            self.gamma = gamma
            self.alpha = alpha
            self.reduction = loss_fcn.reduction
            self.loss_fcn.reduction = 'none'  # required to apply FL to each element
    
        def forward(self, pred, true):
            loss = self.loss_fcn(pred, true)
            # p_t = torch.exp(-loss)
            # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability
    
            # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
            pred_prob = torch.sigmoid(pred)  # prob from logits
            p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
            alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
            modulating_factor = (1.0 - p_t) ** self.gamma
            loss *= alpha_factor * modulating_factor
    
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:  # 'none'
                return loss
    
    
    class QFocalLoss(nn.Module):
        # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
        def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
            super().__init__()
            self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
            self.gamma = gamma
            self.alpha = alpha
            self.reduction = loss_fcn.reduction
            self.loss_fcn.reduction = 'none'  # required to apply FL to each element
    
        def forward(self, pred, true):
            loss = self.loss_fcn(pred, true)
    
            pred_prob = torch.sigmoid(pred)  # prob from logits
            alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
            modulating_factor = torch.abs(true - pred_prob) ** self.gamma
            loss *= alpha_factor * modulating_factor
    
            if self.reduction == 'mean':
                return loss.mean()
            elif self.reduction == 'sum':
                return loss.sum()
            else:  # 'none'
                return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86

    YOLOv5损失函数

    # 总的损失:分类损失cls + 置信度损失obj + 框坐标回归损失loc
    class ComputeLoss:
        def __init__(self, model, cls_balance, autobalance=False, version=2):
            self.cls_balance = cls_balance
            self.sort_obj_iou = False
            device = next(model.parameters()).device  
            # 取出Detect()模块
            if version == 1:
                det = de_parallel(model).model[-1] 
            elif version == 2:
                det = de_parallel(model).detection 
            else:
                pass
            h = model.hyp  
            
            # 定义类别和目标性得分损失函数
            # nn.BCEWithLogitsLoss()等价于 torch.sigmoid() + torch.nn.BCELoss()
            # cls_pw: classes positive_weight 正样本权重
            # obj_pw: objectness positive_weight 是否为目标的权重
            BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device), reduction='none')
            BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device), reduction='mean')
    
            # 类标签平滑技术Class label smoothing https://arxiv.org/pdf/1902.04103.pdf
            self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  
    
            # Focal loss
            # 如果设置了fl_gamma参数,就是用focal loss,默认没有使用。
            g = h['fl_gamma']  
            if g > 0:
                BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
                
            # P3-P7 平衡不同尺度损失
            # 设置三个特征图对应输出的损失系数
            self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) 
        
            # stride 16 index 
            self.ssi = list(det.stride).index(16) if autobalance else 0  
            self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
            
            for k in 'na', 'nc', 'nl', 'anchors':
                setattr(self, k, getattr(det, k))
    
        def __call__(self, p, targets):  
            # 网络输出predictions
            # p[0]-->[bs, 3, 80, 80, nc+5]; 
            # p[1]-->[bs, 3, 40, 40, nc+5]; 
            # p[2]-->[bs, 3, 20, 20, nc+5]
            # 真实标签targets
            # [targets_num, 6],i表示第几张图片;c为类别;xywh为坐标
            
            # 获取设备
            device = targets.device
            # 初始化各个部分损失
            # Objectness loss置信度损失; Classes loss分类损失; Location loss框坐标回归损失
            lobj, lcls, lbox = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
            # 获取标签分类,边框,索引,anchors
            # tcls: [M1, ], [M2, ], [M3, ]    表示不同尺度上的预测框对应的真实类别标签;
            # tbox: [M1, 4], [M2, 4], [M3, 4] 表示不同尺度上的预测框对应的真实坐标;
            # indices: 
            tcls, tbox, indices, anchors = self.build_targets(p, targets)  
          
            # 遍历预测的三个特征图的输出
            for i, pi in enumerate(p):  
                # 根据indices获取索引,方便找到对应网格的输出
                # image索引, anchor索引, gridy, gridx
                b, a, gj, gi = indices[i]
                  
                tobj = torch.zeros_like(pi[..., 0], device=device) 
    
                # 预测框中正样本的个数
                n = b.shape[0]  
                # 正样本个数不为0
                if n:
                    # 找到对应网格的输出,取出对应位置的预测值
                    # 第几张图片的第几个anchor中的什么位置
                    # pi[bs, 3, 80, 80, nc+5]
                    ps = pi[b, a, gj, gi] 
                    
                    # =============框坐标回归Regression=============
                    # 对输出xywh做反算
                    # 这块pxy本质上还是偏移量,为什么可以用来计算iou呢?因为两个的基准是一样的,都是基于某一个网格来说!!!
                    pxy = ps[:, :2].sigmoid() * 2 - 0.5
                    # 注意这块的wh, 是目标的真实大小在每个尺度下的映射; 真实大小/8,真实大小/16,真实大小/32;
                    pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                    # predicted box
                    pbox = torch.cat((pxy, pwh), 1)  
                    
                    #=======================原始IOU========================
                    # 计算边框损失,注意这个CIou=True, 计算的是ciou损失
                    iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  
                    # 维度为n
                    lbox += (1.0 - iou).mean()  
                    
                    # ====================IOU各种变形=======================
                    # iou = bbox_iou(pbox, tbox[i], CIoU=True) 
                    # if type(iou) is tuple:
                    #     lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
                    #     iou = iou[0].squeeze()
                    # else:
                    #     lbox += (1.0 - iou.squeeze()).mean()  # iou loss
                    #     print(lbox)
                    #     iou = iou.squeeze()
                        
                    # =============置信度Objectness=============
                    # 根据model.gr设置objectness的标签值;有目标的conf分支权重;
                    # 不同anchor和gt bbox匹配度不一样,预测框和gt bbox的匹配度也不一样,如果权重设置一样肯定不是最优的,
                    # 故将预测框和真实框bbox的iou作为权重乘到conf分支,用于表征预测质量。
                    score_iou = iou.detach().clamp(0).type(tobj.dtype)
                    if self.sort_obj_iou:
                        sort_id = torch.argsort(score_iou)
                        b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
                    # 这块有点新颖,将预测框和gt的iou作为置信度真实标签的参考值
                    tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou  
    
                    # =============类别Classification=============
                    # 如果类别数大于1,才计算分类损失
                    if self.nc > 1:  
                        t = torch.full_like(ps[:, 5:], self.cn, device=device) 
                        # 设置每个类的标签
                        t[range(n), tcls[i]] = self.cp
                        # 添加类别均衡代码
                        count_balance = [math.pow(it, 1/5) for it in self.cls_balance]
                        cls_weights = [max(count_balance)/it if it !=0 else 1 for it in count_balance]
                        # BCE, 每个类单独计算loss
                        lcls += torch.mean(self.BCEcls(ps[:, 5:], t)*torch.Tensor(cls_weights).cuda()) 
                        # 另外一个行之有效的方法是:增大lcls系数权重,模型将更加关注分类效果!
                        #lcls += torch.mean(self.BCEcls(ps[:, 5:], t))  
                   
                # 这块主要是为了解决添加背景图片,做对抗训练。
                # 计算objectness的损失
                obji = self.BCEobj(pi[..., 4], tobj)
                lobj += obji * self.balance[i]  
                
                if self.autobalance:
                    self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
    
            if self.autobalance:
                self.balance = [x / self.balance[self.ssi] for x in self.balance]
            
            # 根据超参数设置的各个部分损失的系数获得最终损失
            lbox *= self.hyp['box']
            lobj *= self.hyp['obj']
            lcls *= self.hyp['cls']
            # batch size 大小
            bs = tobj.shape[0]  
            return (lbox+lobj+lcls)*bs, torch.cat((lbox, lobj, lcls)).detach()
        """
        匹配正样本
        build_targets函数用于获得在训练时计算loss函数所需要的目标框,即被认为是正样本。
        与yolov3/v4的不同: yolov5支持跨网格预测。
        对于任何一个bbox, 三个输出预测特征层都可能有先验框anchors与之匹配;
        该函数输出的正样本框比传入的targets(GT框)数目多;
        具体处理过程:
        (1) 对于任何一个尺度的特征图,计算当前bbox和当前层anchor的匹配程度,不采用iou,而是shape比例;
            如果anchor和bbox的宽高比差距大于4,则认为不匹配,此时忽略相应的bbox,即当做背景。
        (2) 然后对bbox落在的网格所有anchors都计算loss;
            注意: 此时落在的网格不再是一个而是附近的多个,这样就增加了正样本数,可能存在有些bbox在三个尺度都预测的情况;
            另外,yolov5也没有conf分支忽略阈值(ignore_thresh)的操作,而yolov3/v4有。
        """
        def build_targets(self, p, targets):
            """
            Args:
                p: 网络输出, List[torch.tensor * 3]; p[i].shape = (b,3,h,w,nc+5) h,w分别为特征图的长宽, b为batch-size。
                targets: GT框; targets.shape = (nt, 6), 6=i,c,x,y,w,h, i表示第i+1张图片, c为类别, 坐标xywh。
            Returns:
            """
            # targets(image, class, x, y, w, h)
            # image表示图片在当前batch的id号,class表示类别id,后面依次是归一化了的gt框的x,y,w,h坐标。
              
            # na表示每个尺度特征图的anchor数量,这里为3。
            # nt表示一个batch-size中target数量。
            na, nt = self.na, targets.shape[0]  
            
            tcls, tbox, indices, anch = [], [], [], []
    
            gain = torch.ones(7, device=targets.device)  
            # ai-->(na, nt) 生成anchor索引
            # anchor索引,后面有用,用于表示当前bbox和当前层的哪个anchor匹配。
            ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) 
            # 先repeat targets和当前层anchor个数一样,相当于每个bbox变成了3个,然后和3个anchor单独匹配。
            # targets [3, nt, 6]--->[3, nt, 7],增加anchor indices
            targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  
    
            # 设置网络中心偏移量
            g = 0.5  
            # off-->(5,2) 附近的4个网格(上下左右)
            # [0, 0], [0.5, 0], [0, 0.5], [0, -0.5], [-0.5, 0]
            off = torch.tensor([[0, 0],
                                [1, 0], [0, 1], [-1, 0], [0, -1],      # j,k,l,m
                                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm 斜上,斜下,斜左,斜右
                                ], device=targets.device).float() * g  # offsets
    
            # 对每个特征图进行操作,顺序为降采样8-16-32 
            # 三个尺度的预测特征图输出分支
            for i in range(self.nl):
                # 获取该层特征图中的anchors(已经除以了当前特征图对应的stride)
                anchors, shape = self.anchors[i], p[i].shape
                """
                p[i].shape = (b, 3, h, w, nc+5)
                gain = [1, 1, w, h, w, h, 1]
                """
                # p是网络输出值
                gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  
               
                # 将真实框targets与锚框anchors进行匹配
                # 将标签框的xywh从基于0-1映射到基于特征图;targets的xywh本身是归一化尺度,故需要变成特征图尺度。
                t = targets * gain
                if nt:
                    """
                    真实的wh与anchor的wh做匹配,筛选掉比值大于hyp["anchor_t"]的,从而更好地回归。
                    作者采用新的wh回归方式: (wh.sigmoid() * 2) ** 2 * anchors[i]
                    原来yolov3/v4为anchors[i] * exp(wh)。
                    将标签框与anchor的倍数控制在0~4之间; hyp.scratch.yaml中的超参数anchor_t=4,用于判定anchors与标签框契合度;
                    """
                    # 计算当前target的wh和anchor的wh比例值
                    # 如果最大比例大于预设值model.hyp["anchor_t"]=4,则当前target和anchor匹配度不高,不强制回归,而把target丢弃;
                    # 计算wh比值ratio, 不考虑xy坐标 
                    # t[:, :, 4:6] --> [3, nt, 2] 
                    # anchors[:, None] --> [3, 1, 2]
                    r = t[:, :, 4:6] / anchors[:, None]  
                    # 筛选满足 1/hyp["anchor_t"] < target_wh/anchor_wh < hyp["anchor_t"]的框;
                    # j.shape = (3, nt) = (na, nt)
                    j = torch.max(r, 1/r).max(2)[0] < self.hyp['anchor_t']  
                    # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
    
                    # 筛选过后的t.shape=(M, 7), M为筛选过后的数量;
                    # 注意: 过滤规则没有考虑xy, 也就是当前bbox的wh和所有anchors计算的。
                    t = t[j]
                  
                    # Offsets
                    # 获取选择完成的box的中心点左边-gxy(以特征图左上角为坐标原点),并转换为以特征图右下角为坐标原点的坐标-gxi。
                    gxy = t[:, 2:4]  
                    gxi = gain[[2, 3]] - gxy  
                    
                    """
                    把相对于各个网格左上角x<0.5或y<0.5和相对于右下角的x<0.5或y<0.5的框提取出来,也就是j,k,l,m;
                    在选取gij(也就是标签框分配给的网格)的时候,对这四个部分的框都做一个偏移(减去上面的offesets);
                    也就是下面的gij = (gxy - offsets).long操作;
                    再将这四个部分的框跟原始的gij拼接在一起,总共就是五个部分;
                    yolov3/v4仅仅采用当前网格的anchor进行回归; yolov4也有解决网格跑偏的措施,即通过sigmoid限制输出;
                    yolov5中心点回归从yolov3/v4的0~1范围变成-0.5~1.5的范围;
                    中心点回归的公式变为:
                    xy.sigmoid()*2.0 - 0.5 + cx (其中对原始中心点网格坐标扩展了两个邻居像素)
                    """
                    # 对于筛选后的bbox,计算其落在哪个网格内,同时找出邻近的网格,将这些网格都认为是负责预测该bbox的网格;
                    # 浮点数取模的数学含义:对于两个浮点数a和b, a % b = a - n * b, 其中n为不超过a/b的最大整数。
                    # ((gxy % 1 < g) & (gxy > 1)).T的shape为(2,M)
                    # j,k,l,m的shape均为(M, )
                    j, k = ((gxy % 1 < g) & (gxy > 1)).T
                    l, m = ((gxi % 1 < g) & (gxi > 1)).T
                    # j.shape (5, M)
                    j = torch.stack((torch.ones_like(j), j, k, l, m))
                    # 5是因为预设的off是5个
                    # 接近 (M*3, 7)
                    t = t.repeat((5, 1, 1))[j]
                    # 添加偏移量 (1, M, 2) + (5, 1, 2) = (5, M, 2) -->  接近(M*3, 2)
                    offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
                else:
                    t = targets[0]
                    offsets = 0
                """
                对每个bbox找出对应的正样本anchor, 其中:
                b表示当前bbox属于batch内部的第几张图片;
                c是该bbox类别;
                gxy是对应bbox的中心点坐标xy;
                gwh是对应bbox的wh;
                a表示当前bbox和当前层的第几个anchor匹配上;
                gi, gj是对应的负责预测该bbox的网格坐标;
                """
                # 获取每个box的图像索引和类别
                b, c = t[:, :2].long().T  
                # 中心点回归标签
                gxy = t[:, 2:4] 
                # 宽高回归标签
                gwh = t[:, 4:6]  
                # 当前label落在哪个网格上面
                gij = (gxy - offsets).long()
                gi, gj = gij.T  # grid xy indices(索引值)
                
                # a为anchor索引
                a = t[:, 6].long() 
                # 待分类的类别class
                tcls.append(c)  
                # 待回归的box坐标值
                tbox.append(torch.cat((gxy-gij, gwh), 1))  
                # 添加索引,方便计算损失的时候取出对应位置的输出;
                # torch.clamp详解(限制张量取值范围)
                indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  
                # anchor尺寸大小
                anch.append(anchors[a])  
            return tcls, tbox, indices, anch
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292

    YOLOX损失函数

    # 总的损失total:分类损失cls + 置信度损失obj + 框坐标回归损失loc
    class ComputeLossOTA:
        def __init__(self, model, autobalance=False, version=2):
            super(ComputeLossOTA, self).__init__()
            # 获取device
            device = next(model.parameters()).device  
            # 超参数hyperparameters
            h = model.hyp  
    
            # 定义类别和目标性得分损失函数
            # nn.BCEWithLogitsLoss()等价于 torch.sigmoid() + torch.nn.BCELoss()
            # cls_pw: classes positive_weight 正样本权重
            # obj_pw: objectness positive_weight 是否为目标的权重
            BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
            BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
    
            # 类标签平滑技术Class label_smoothing https://arxiv.org/pdf/1902.04103.pdf
            self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  
    
            # Focal loss
            # 如果设置了fl_gamma参数,就使用focal loss,默认没有使用。
            g = h['fl_gamma']  
            if g > 0:
                BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
    
            # 检测头det
            # 取出Detect()模块
            if version == 1:
                det = de_parallel(model).model[-1] 
            elif version == 2:
                det = de_parallel(model).detection 
            else:
                pass
        
            # P3-P7 平衡不同尺度损失
            # 设置三个特征图对应输出的损失系数
            self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  
            # stride 16 index
            self.ssi = list(det.stride).index(16) if autobalance else 0  
            self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
            for k in 'na', 'nc', 'nl', 'anchors', 'stride':
                setattr(self, k, getattr(det, k))
    
        def __call__(self, p, targets, imgs):  
            # 网络输出predictions
            # p[0]-->[bs, 3, 80, 80, nc+5]; 
            # p[1]-->[bs, 3, 40, 40, nc+5]; 
            # p[2]-->[bs, 3, 20, 20, nc+5];
    
            # 真实标签targets
            # [targets_num, 6], i表示第几张图片; c为类别; xywh为坐标.
            
            # 获取设备
            device = targets.device
            # 初始化各个部分损失
            # Objectness loss置信度损失; Classes loss分类损失; Location loss框坐标回归损失
            lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
            # 正负样本的匹配
            bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
            pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p] 
        
            # 遍历预测的三个特征图的输出
            for i, pi in enumerate(p): 
                # 根据indices获取索引,方便找到对应网格的输出
                # image索引, anchor索引, gridy, gridx
                b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]  
                
                # target obj
                tobj = torch.zeros_like(pi[..., 0], device=device)  
                # 预测框中正样本的个数
                n = b.shape[0]  
                # 如果正样本个数不为0
                if n:
                    # 找到对应网格的输出,取出对应位置的预测值
                    # 第几张图片的第几个anchor中的什么位置
                    # pi[bs, 3, 80, 80, nc+5]
                    # [pi_n, nc+5]  pi_n表示每个特征图的正样本个数
                    ps = pi[b, a, gj, gi]  
    
                    # =============框坐标回归Regression=============
                    grid = torch.stack([gi, gj], dim=1)
                    # 对输出xywh做反算
                    # 这块pxy本质上还是偏移量,为什么可以用来计算iou呢?因为两个的基准是一样的,都是基于某一个网格来说!!!
                    pxy = ps[:, :2].sigmoid() * 2. - 0.5
                    #pxy = ps[:, :2].sigmoid() * 3. - 1.
                    pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                    # predicted box
                    pbox = torch.cat((pxy, pwh), 1)  
                    # 将gt框也换算到特征图尺度上
                    selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
                    selected_tbox[:, :2] -= grid
                    # 计算边框损失,注意这个CIou=True, 计算的是ciou损失
                    # [pi_n] 
                    iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) 
                    # 维度为n
                    lbox += (1.0 - iou).mean()  
    
                    # =============置信度Objectness=============
                    # 根据model.gr设置objectness的标签值;
                    # 不同anchor和gt bbox匹配度不一样,预测框和gt bbox的匹配度也不一样,如果权重设置一样肯定不是最优的,
                    # 故将预测框和真实框bbox的iou作为权重乘到obj分支,用于表征预测质量。
                    # 这块有点新颖,将预测框和gt的iou作为置信度真实标签的一个分支来参考。
                    tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  
    
                    # =============类别Classification=============
                    # gt类别标签
                    selected_tcls = targets[i][:, 1].long()
                    # 如果类别数大于1,才计算分类损失
                    if self.nc > 1:  
                        # [pi_n, 2]  pi_n表示每个特征图的正样本个数
                        t = torch.full_like(ps[:, 5:], self.cn, device=device)  
                        # 设置每个类的标签
                        t[range(n), selected_tcls] = self.cp
                        # BCE, 每个类单独计算loss
                        lcls += self.BCEcls(ps[:, 5:], t) 
    
                # 这块主要是为了解决添加背景图片,做对抗训练。
                # 计算objectness的损失
                obji = self.BCEobj(pi[..., 4], tobj)
                # 这个尺度的损失权重
                lobj += obji * self.balance[i]  
                
                if self.autobalance:
                    self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
    
            if self.autobalance:
                self.balance = [x / self.balance[self.ssi] for x in self.balance]
            
            # 根据超参数设置的各个部分损失的系数获得最终损失
            lbox *= self.hyp['box']
            lobj *= self.hyp['obj']
            lcls *= self.hyp['cls']
            # batch size大小
            bs = tobj.shape[0] 
            # 总的损失loss
            return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
           
        def build_targets(self, p, targets, imgs):
            """
            网络输出predictions
            p[0]-->[bs, 3, 80, 80, nc+5]; 
            p[1]-->[bs, 3, 40, 40, nc+5]; 
            p[2]-->[bs, 3, 20, 20, nc+5] 
    
            真实标签targets
            [targets_num, 6], (index,c,x,y,w,h) index表示第几张图片; c为类别; xywh为坐标
            """
            # 1.YOLOv5正负样本分配策略
            indices, anch = self.find_3_positive(p, targets)
        
            # 2.YOLOX正负样本分配策略
            matching_bs = [[] for pp in p]
            matching_as = [[] for pp in p]
            matching_gjs = [[] for pp in p]
            matching_gis = [[] for pp in p]
            matching_targets = [[] for pp in p]
            matching_anchs = [[] for pp in p]
            
            # 检测层的输出数量(不同尺度个数)
            nl = len(p)    
            
            # 对于每一张图片来说
            for batch_idx in range(p[0].shape[0]):
                # 找到batch_idx的gt目标
                b_idx = targets[:, 0]==batch_idx
                # (p_nt, 6) 每张图片对应的targets的数量
                this_target = targets[b_idx]
                if this_target.shape[0] == 0:
                    continue
                # ?????????????????????
                # this_target*wh,还原GT框, shape=(p_nt, 4)
                txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
                # 格式转换成(x1,y1,x2,y2), shape=(p_nt, 4)
                txyxy = xywh2xyxy(txywh)
    
                pxyxys = []
                p_obj = []
                p_cls = []
                from_which_layer = []
                all_b = []
                all_a = []
                all_gj = []
                all_gi = []
                all_anch = []
                
                # 对于每一个尺度的输出来说
                for i, pi in enumerate(p):
                    b, a, gj, gi = indices[i]
                    # 把当前图片对应的该层特征图的正样本取出来
                    idx = (b == batch_idx)
                    b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]                
                    all_b.append(b)
                    all_a.append(a)
                    all_gj.append(gj)
                    all_gi.append(gi)
                    all_anch.append(anch[i][idx])
                    from_which_layer.append(torch.ones(size=(len(b),)) * i)
                    
                    # 当前图片对应的该层特征图的正样本预测preds结果
                    # (p_i_pred, 4+1+nc)
                    fg_pred = pi[b, a, gj, gi]   
                    # obj预测值
                    p_obj.append(fg_pred[:, 4:5]) 
                    # cls预测值
                    p_cls.append(fg_pred[:, 5:])
                    
                    # (p_i_pred, 2)
                    grid = torch.stack([gi, gj], dim=1)
                    # 解码阶段
                    pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i] #/ 8.
                    #pxy = (fg_pred[:, :2].sigmoid() * 3. - 1. + grid) * self.stride[i]
                    pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i] #/ 8.
                    pxywh = torch.cat([pxy, pwh], dim=-1)
                    # (p_i_pred, 4)
                    pxyxy = xywh2xyxy(pxywh)
                    pxyxys.append(pxyxy)
                
                # 把当前图片对应的三层特征图的正样本取出来 
                pxyxys = torch.cat(pxyxys, dim=0)
                if pxyxys.shape[0] == 0:
                    continue
                p_obj = torch.cat(p_obj, dim=0)
                p_cls = torch.cat(p_cls, dim=0)
                from_which_layer = torch.cat(from_which_layer, dim=0)
                all_b = torch.cat(all_b, dim=0)
                all_a = torch.cat(all_a, dim=0)
                all_gj = torch.cat(all_gj, dim=0)
                all_gi = torch.cat(all_gi, dim=0)
                all_anch = torch.cat(all_anch, dim=0)
                
                # =======================a.Loss函数计算========================
                # 1. 回归IOU损失
                # 计算GT与预测正样本框的IOU矩阵, (p_nt, p_pred)
                pair_wise_iou = box_iou(txyxy, pxyxys)
                # 计算GT与预测正样本框的IOU矩阵loss
                pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
                
                # 2. 分类*置信度损失
                # 对每一个GT的label做one hot,然后重复p_pred次,shape=(p_nt, p_pred, 85)
                # gt_classes:tensor([0., 0., 0.], device='cuda:0')--->torch.Size([3])
                # F.one_hot(gt_classes.to(torch.int64), self.num_classes): 
                # tensor([[1, 0, 0, 0],
                #         [1, 0, 0, 0],
                #         [1, 0, 0, 0]], device='cuda:0') --->torch.Size([3, 4])
                # F.one_hot(gt_classes.to(torch.int64), self.num_classes).float().unsqueeze(1)
                # tensor([[[1., 0., 0., 0.]],
                #         [[1., 0., 0., 0.]],
                #         [[1., 0., 0., 0.]]], device='cuda:0') --->torch.Size([3, 1, 4])
                # gt_cls_per_image: (p_nt, p_pred, nc) 
                gt_cls_per_image = (
                    F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
                    .float()
                    .unsqueeze(1)
                    .repeat(1, pxyxys.shape[0], 1)
                )
                
                # 当前图片有多少个gtbox目标
                num_gt = this_target.shape[0]
                # (p_gt, p_pred, nc) * (p_gt, p_pred, 1) 置信度得分*类别得分
                cls_preds_ = (
                    p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                    * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                )
                
                # (p_gt, p_pred) 分类*置信度损失
                y = cls_preds_.sqrt_()
                pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
                   torch.log(y/(1-y)), gt_cls_per_image, reduction="none").sum(-1)
                del cls_preds_
            
                # =======================b.cost成本计算========================
                # (p_gt, p_pred)
                cost = (pair_wise_cls_loss + 3.0 * pair_wise_iou_loss)
    
                # =======================c.SimOTA求解========================
                # &&&&&&&&&&&&&&&&&& 第一步:设置候选框数量 &&&&&&&&&&&&&&&&&&&&&&
                # 首先按照cost值的大小,新建一个全0变量matching_matrix,这里是[p_gt, p_pred]。
                matching_matrix = torch.zeros_like(cost)
                # torch.topk默认从大到小进行排序,找到前k个数, shape=(p_nt, 10) 即每个gtbox都取自己排名前10的IOU
                top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
                # top_k.sum(1).int()对前10个iou相加并取整,每一个表示GT需要取dynamic_ks的正样本,shape=[p_nt],
                # clamp是区间函数,每一个目标保证必须有一个正样本,因此不能小于1。
                dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
                
                # &&&&&&&&&&&&&&&&&& 第二步:通过cost挑选候选框 &&&&&&&&&&&&&&&&&&
                # 对于每一个gtbox来说, 相应的cost值最低的一些候选框
                for gt_idx in range(num_gt):
                    # 取cost[gt_idx]矩阵的前dynamic_k个位置下标.
                    _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
                    # 将前dynamic_k个位置下标置1.
                    matching_matrix[gt_idx][pos_idx] = 1.0
                del top_k, dynamic_ks
    
                # &&&&&&&&&&&&&&&&&& 第三步:过滤共用的候选框 &&&&&&&&&&&&&&&&&&
                # (p_pred, )
                anchor_matching_gt = matching_matrix.sum(0)
                # 在anchor_matching_gt中,只要有大于1的,说明有共用的情况。
                if (anchor_matching_gt > 1).sum() > 0:
                    # cost_argmin是针对于共用列来说,取出与它cost最小的那个(p_nt中某一个)的行索引。
                    _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                    matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                    matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
                # 前景mask
                fg_mask_inboxes = matching_matrix.sum(0) > 0.0
                # 最合适的候选框分别对应的真实框box
                matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
              
                from_which_layer = from_which_layer[fg_mask_inboxes]
                all_b = all_b[fg_mask_inboxes]
                all_a = all_a[fg_mask_inboxes]
                all_gj = all_gj[fg_mask_inboxes]
                all_gi = all_gi[fg_mask_inboxes]
                all_anch = all_anch[fg_mask_inboxes]
                this_target = this_target[matched_gt_inds]
                
                # 对于每一个尺度的输出来说
                for i in range(nl):
                    layer_idx = from_which_layer == i
                    matching_bs[i].append(all_b[layer_idx])
                    matching_as[i].append(all_a[layer_idx])
                    matching_gjs[i].append(all_gj[layer_idx])
                    matching_gis[i].append(all_gi[layer_idx])
                    matching_anchs[i].append(all_anch[layer_idx])
                    matching_targets[i].append(this_target[layer_idx])
            
            # 对于每一个尺度的输出来说,把一个batch_size的图片拼接起来
            for i in range(nl):
                if matching_targets[i] != []:
                    matching_bs[i] = torch.cat(matching_bs[i], dim=0)
                    matching_as[i] = torch.cat(matching_as[i], dim=0)
                    matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
                    matching_gis[i] = torch.cat(matching_gis[i], dim=0)
                    matching_targets[i] = torch.cat(matching_targets[i], dim=0)
                    matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
                else:
                    matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                    matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                    matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                    matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                    matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
                    matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
    
            return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs           
        """
        匹配正样本
        build_targets函数用于获得在训练时计算loss函数所需要的目标框,即被认为是正样本。
        与yolov3/v4的不同: yolov5支持跨网格预测。
        对于任何一个bbox, 三个输出预测特征层都可能有先验框anchors与之匹配;
        该函数输出的正样本框比传入的targets(GT框)数目多;
        具体处理过程:
        (1) 对于任何一个尺度的特征图,计算当前bbox和当前层anchor的匹配程度,不采用iou,而是shape比例;
            如果anchor和bbox的宽高比差距大于4,则认为不匹配,此时忽略相应的bbox,即当做背景。
        (2) 然后对bbox落在的网格所有anchors都计算loss;
            注意:此时落在的网格不再是一个而是附近的多个,这样就增加了正样本数,可能存在有些bbox在三个尺度都预测的情况;
            另外,yolov5也没有conf分支忽略阈值(ignore_thresh)的操作,而yolov3/v4有。
        """
        def find_3_positive(self, p, targets):
            """
            p: 网络输出, p[i](b,3,h,w,nc+5) h,w分别为特征图的长宽, b为batch-size。
            targets: GT框, targets(nt, 6), 6=i,c,x,y,w,h, i表示第i+1张图片, c为类别, 坐标xywh。
            """
            # na表示每个尺度特征图的anchor数量,这里为3。
            # nt表示一个batch-size中target数量。
            na, nt = self.na, targets.shape[0]  
            indices, anch = [], []
            gain = torch.ones(7, device=targets.device).long()  
            # ai-->(na, nt) 生成anchor索引
            # anchor索引,后面有用,用于表示当前bbox和当前层的哪个anchor匹配。
            ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  
            # 先repeat targets和当前层anchor个数一样,相当于每个bbox变成了3个,然后和3个anchor单独匹配。
            # targets [3, nt, 6]--->[3, nt, 7],增加anchor indices
            targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  
    
            # 设置网络中心偏移量
            g = 0.5  
            # off-->(5,2) 附近的4个网格(上下左右)
            # [0, 0], [0.5, 0], [0, 0.5], [0, -0.5], [-0.5, 0]
            off = torch.tensor([[0, 0],
                                [1, 0], [0, 1], [-1, 0], [0, -1],      # j,k,l,m
                                # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm 斜上,斜下,斜左,斜右
                                ], device=targets.device).float() * g  # offsets
    
            # 对每个特征图进行操作,顺序为降采样8-16-32 
            # 三个尺度的预测特征图输出分支
            for i in range(self.nl):
                # 获取该层特征图中的anchors(已经除以了当前特征图对应的stride)
                anchors = self.anchors[i]
    
                # xyxy gain
                gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  
    
                # 将真实框targets与锚框anchors进行匹配。
                # 将标签框的xywh从基于0-1映射到基于特征图;targets的xywh本身是归一化尺度,故需要变成特征图尺度。
                # 真实目标框在当前特征图上的大小
                t = targets * gain
                if nt:
                    # 计算当前target的wh和anchor的wh比例值
                    # 如果最大比例大于预设值model.hyp["anchor_t"]=4,则当前target和anchor匹配度不高,不强制回归,而把target丢弃;
                    # 计算wh比值ratio, 不考虑xy坐标 
                    # t[:, :, 4:6] --> [3, nt, 2]
                    # anchors[:, None] --> [3, 1, 2]
                    r = t[:, :, 4:6] / anchors[:, None]  
                    # 筛选满足 1/hyp["anchor_t"] < target_wh/anchor_wh < hyp["anchor_t"]的框;
                    # j.shape = (3, nt) = (na, nt)
                    j = torch.max(r, 1./r).max(2)[0] < self.hyp['anchor_t']  
                    # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                    
                    # 筛选过后的t.shape=(M, 7), M为筛选过后的数量;
                    # 注意: 过滤规则没有考虑xy, 也就是当前bbox的wh和所有anchors计算的。
                    t = t[j]  
    
                    # Offsets
                    # 获取选择完成的box的中心点左边gxy(以特征图左上角为坐标原点),并转换为以特征图右下角为坐标原点的坐标gxi。
                    gxy = t[:, 2:4]  
                    gxi = gain[[2, 3]] - gxy  
    
                    """
                    把相对于各个网格左上角x<0.5或y<0.5和相对于右下角的x<0.5或y<0.5的框提取出来,也就是j,k,l,m;
                    在选取gij(也就是标签框分配给的网格)的时候,对这四个部分的框都做一个偏移(减去上面的offesets);
                    也就是下面的gij = (gxy - offsets).long操作;
                    再将这四个部分的框跟原始的gij拼接在一起,总共就是五个部分;
                    yolov3/v4仅仅采用当前网格的anchor进行回归; yolov4也有解决网格跑偏的措施,即通过sigmoid限制输出;
                    yolov5中心点回归从yolov3/v4的0~1范围变成-0.5~1.5的范围;
                    中心点回归的公式变为:
                    xy.sigmoid()*2.0 - 0.5 + cx (其中对原始中心点网格坐标扩展了两个邻居像素)
                    """
                    # 对于筛选后的bbox,计算其落在哪个网格内,同时找出邻近的网格,将这些网格都认为是负责预测该bbox的网格;
                    # 浮点数取模的数学含义:对于两个浮点数a和b, a % b = a - n * b, 其中n为不超过a/b的最大整数。
                    # ((gxy % 1 < g) & (gxy > 1)).T的shape为(2,M)
                    # j,k,l,m的shape均为(M, )
                    j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                    l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                    # j.shape (5, M)
                    j = torch.stack((torch.ones_like(j), j, k, l, m))
                    # 5是因为预设的off是5个
                    # 接近 (M*3, 7)
                    t = t.repeat((5, 1, 1))[j]
                    # 添加偏移量 (1, M, 2) + (5, 1, 2) = (5, M, 2) -->  接近(M*3, 2)
                    offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
                # 这块怎么理解???????????????????
                else:
                    t = targets[0]
                    offsets = 0
    
                """
                对每个bbox找出对应的正样本anchor, 其中:
                b表示当前bbox属于batch内部的第几张图片;
                c是该bbox类别;
                gxy是对应bbox的中心点坐标xy;
                gwh是对应bbox的wh;
                a表示当前bbox和当前层的第几个anchor匹配上;
                gi, gj是对应的负责预测该bbox的网格坐标;
                """
                # 获取每个box的图像索引和类别
                b, c = t[:, :2].long().T  
                # 中心点回归标签
                gxy = t[:, 2:4]  
                # 宽高回归标签
                gwh = t[:, 4:6]  
                # 当前label落在哪个网格上面
                gij = (gxy - offsets).long()
                gi, gj = gij.T  # grid xy indices(索引值)
    
                # a为当前层anchor索引
                a = t[:, 6].long()  
                # 添加索引,方便计算损失的时候取出对应位置的输出;
                # torch.clamp详解(限制张量取值范围)
                indices.append((b, a, gj.clamp_(0, gain[3]-1), gi.clamp_(0, gain[2]-1))) 
                # anchor尺寸大小
                anch.append(anchors[a])  
    
            return indices, anch
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336
    • 337
    • 338
    • 339
    • 340
    • 341
    • 342
    • 343
    • 344
    • 345
    • 346
    • 347
    • 348
    • 349
    • 350
    • 351
    • 352
    • 353
    • 354
    • 355
    • 356
    • 357
    • 358
    • 359
    • 360
    • 361
    • 362
    • 363
    • 364
    • 365
    • 366
    • 367
    • 368
    • 369
    • 370
    • 371
    • 372
    • 373
    • 374
    • 375
    • 376
    • 377
    • 378
    • 379
    • 380
    • 381
    • 382
    • 383
    • 384
    • 385
    • 386
    • 387
    • 388
    • 389
    • 390
    • 391
    • 392
    • 393
    • 394
    • 395
    • 396
    • 397
    • 398
    • 399
    • 400
    • 401
    • 402
    • 403
    • 404
    • 405
    • 406
    • 407
    • 408
    • 409
    • 410
    • 411
    • 412
    • 413
    • 414
    • 415
    • 416
    • 417
    • 418
    • 419
    • 420
    • 421
    • 422
    • 423
    • 424
    • 425
    • 426
    • 427
    • 428
    • 429
    • 430
    • 431
    • 432
    • 433
    • 434
    • 435
    • 436
    • 437
    • 438
    • 439
    • 440
    • 441
    • 442
    • 443
    • 444
    • 445
    • 446
    • 447
    • 448
    • 449
    • 450
    • 451
    • 452
    • 453
    • 454
    • 455
    • 456
    • 457
    • 458
    • 459
    • 460
    • 461
    • 462
    • 463
    • 464
    • 465
    • 466
    • 467
    • 468
    • 469
    • 470
    • 471
    • 472
    • 473
  • 相关阅读:
    Eureka 相关配置及特性
    下班后根本联系不上,这样的员工可以辞退吗
    apache-poi导出数据到excel(SXSSF)
    vue3 props 传值
    代码随想录11——栈与队列:理论基础、232.用栈实现队列、225.用队列实现栈、20.有效的括号、1047. 删除字符串中的所有相邻重复项
    SpringBoot定时任务 - Spring自带的定时任务是如何实现的?有何注意点?
    CarEye 智能云平台升级
    论文数据去哪找?
    国内外的免费AI作图工具
    用vscode调试ros或ros2的python程序(rclpy)
  • 原文地址:https://blog.csdn.net/weixin_43593330/article/details/136215080