import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.metrics import bbox_iou
from utils.torch_utils import de_parallel
from utils.general import xywh2xyxy, box_iou
def smooth_BCE(eps=0.1):
    """Return (positive, negative) BCE targets for label smoothing.

    With smoothing eps, positives are trained toward 1 - eps/2 and
    negatives toward eps/2 instead of hard 1/0 labels.
    """
    positive = 1.0 - 0.5 * eps
    negative = 0.5 * eps
    return positive, negative
class BCEBlurWithLogitsLoss(nn.Module):
    """BCE-with-logits loss that down-weights likely missing-label cases.

    When the prediction is confidently positive but the target is 0
    (pred_prob - true close to 1), the sample is probably an unlabelled
    object rather than a true negative, so its loss is suppressed.
    """

    def __init__(self, alpha=0.05):
        super().__init__()
        self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none')  # keep per-element losses
        self.alpha = alpha  # sharpness of the missing-label suppression

    def forward(self, pred, true):
        raw = self.loss_fcn(pred, true)
        prob = torch.sigmoid(pred)  # logits -> probabilities
        residual = prob - true
        # weight -> 0 as residual -> 1 (suspected missing label), ~1 otherwise
        weight = 1 - torch.exp((residual - 1) / (self.alpha + 1e-4))
        return (raw * weight).mean()
class FocalLoss(nn.Module):
    """Wrap an nn.BCEWithLogitsLoss criterion with focal-loss weighting.

    Hard examples (low p_t) keep most of their loss while easy ones are
    damped by (1 - p_t) ** gamma (Lin et al., "Focal Loss for Dense
    Object Detection").
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # expected: nn.BCEWithLogitsLoss
        self.gamma = gamma
        self.alpha = alpha
        # remember the caller's reduction, then force element-wise losses
        # so the focal weights can be applied before reducing
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        raw = self.loss_fcn(pred, true)
        prob = torch.sigmoid(pred)
        p_t = true * prob + (1 - true) * (1 - prob)  # prob of the true class
        alpha_t = true * self.alpha + (1 - true) * (1 - self.alpha)
        focal = (1.0 - p_t) ** self.gamma
        weighted = raw * (alpha_t * focal)
        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted
class QFocalLoss(nn.Module):
    """Quality focal loss wrapper around an nn.BCEWithLogitsLoss criterion.

    Like FocalLoss, but the modulating factor is |true - sigmoid(pred)| ** gamma,
    which supports soft (quality) targets, not just hard 0/1 labels.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # expected: nn.BCEWithLogitsLoss
        self.gamma = gamma
        self.alpha = alpha
        # remember the caller's reduction, then force element-wise losses
        # so the quality-focal weights can be applied before reducing
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def forward(self, pred, true):
        raw = self.loss_fcn(pred, true)
        prob = torch.sigmoid(pred)
        alpha_t = true * self.alpha + (1 - true) * (1 - self.alpha)
        focal = torch.abs(true - prob) ** self.gamma  # distance to the (soft) target
        weighted = raw * (alpha_t * focal)
        if self.reduction == 'sum':
            return weighted.sum()
        if self.reduction == 'mean':
            return weighted.mean()
        return weighted

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
# YOLOv5 loss function
class ComputeLoss:
    """YOLOv5 loss: CIoU box loss + objectness BCE + class BCE with
    per-class re-weighting derived from ``cls_balance`` sample counts."""

    def __init__(self, model, cls_balance, autobalance=False, version=2):
        """
        Args:
            model: detection model exposing ``.hyp`` and a Detect head.
            cls_balance: per-class sample counts used to derive class weights.
            autobalance: adapt the per-layer objectness balance online.
            version: 1 -> head is ``model.model[-1]``; 2 -> ``model.detection``.

        Raises:
            ValueError: if ``version`` is not 1 or 2.
        """
        self.cls_balance = cls_balance
        self.sort_obj_iou = False
        device = next(model.parameters()).device
        if version == 1:
            det = de_parallel(model).model[-1]
        elif version == 2:
            det = de_parallel(model).detection
        else:
            # was a silent `pass`, which left `det` unbound and crashed below
            raise ValueError(f'unsupported version: {version}')
        h = model.hyp
        # cls BCE stays element-wise so per-class weights can be applied in __call__
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device), reduction='none')
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device), reduction='mean')
        # label-smoothing targets: cp for positives, cn for negatives
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))
        g = h['fl_gamma']
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
        # per-layer objectness weights (3 layers P3-P5, else 5 layers P3-P7)
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02])
        self.ssi = list(det.stride).index(16) if autobalance else 0  # stride-16 layer index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets):
        """Compute the total loss.

        Args:
            p: list of per-layer predictions, p[i] shape (bs, na, h, w, nc + 5).
            targets: ground truth, shape (nt, 6) = (image_idx, class, x, y, w, h).

        Returns:
            (total_loss * batch_size, detached tensor [lbox, lobj, lcls]).
        """
        device = targets.device
        lobj, lcls, lbox = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        tcls, tbox, indices, anchors = self.build_targets(p, targets)
        if self.nc > 1:
            # Per-class weights from sample counts (5th root damps the imbalance).
            # Loop-invariant, so computed once per call; built on `device` instead
            # of the previous hard-coded .cuda(), which broke CPU / non-default-GPU runs.
            count_balance = [math.pow(it, 1 / 5) for it in self.cls_balance]
            cls_weights = torch.tensor(
                [max(count_balance) / it if it != 0 else 1 for it in count_balance],
                device=device)
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = indices[i]  # image, anchor, grid-y, grid-x
            tobj = torch.zeros_like(pi[..., 0], device=device)  # objectness targets
            n = b.shape[0]  # number of positives on this layer
            if n:
                ps = pi[b, a, gj, gi]  # predictions at the positive cells
                # box decode: xy in (-0.5, 1.5) around the cell, wh in (0, 4) * anchor
                pxy = ps[:, :2].sigmoid() * 2 - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)
                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()  # CIoU loss
                # objectness target is the detached IoU; optional sort so that
                # higher-IoU duplicates are written last (in-place assignment wins)
                score_iou = iou.detach().clamp(0).type(tobj.dtype)
                if self.sort_obj_iou:
                    sort_id = torch.argsort(score_iou)
                    b, a, gj, gi, score_iou = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id], score_iou[sort_id]
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou
                if self.nc > 1:  # classification loss only when multi-class
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)
                    t[range(n), tcls[i]] = self.cp
                    lcls += torch.mean(self.BCEcls(ps[:, 5:], t) * cls_weights)
            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # layer-weighted objectness loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size (from the last layer's objectness map)
        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets):
        """Select positive samples ("build targets") for the loss.

        Unlike YOLOv3/v4, YOLOv5 matches anchors by width/height ratio instead
        of IoU, and assigns each target to up to two neighbouring grid cells
        as well as its own, so more positives come out than targets go in.
        There is no objectness ignore threshold.

        Args:
            p: list of per-layer predictions, p[i] shape (bs, na, h, w, nc + 5).
            targets: (nt, 6) = (image_idx, class, x, y, w, h), xywh normalised.

        Returns:
            tcls, tbox, indices, anch: per-layer class ids, target boxes
            (xy offset within the cell + wh in grid units), (b, a, gj, gi)
            index tuples, and the matched anchors.
        """
        na, nt = self.na, targets.shape[0]
        tcls, tbox, indices, anch = [], [], [], []
        gain = torch.ones(7, device=targets.device)  # normalised -> grid units
        # replicate each target per anchor and append the anchor index as column 6
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)
        g = 0.5  # neighbour-cell offset threshold
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # x-left, y-top, x-right, y-bottom
                            ], device=targets.device).float() * g
        for i in range(self.nl):
            anchors, shape = self.anchors[i], p[i].shape
            # p[i].shape = (bs, na, h, w, nc+5) -> gain = [1, 1, w, h, w, h, 1]
            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]
            t = targets * gain  # xywh now in grid units
            if nt:
                # Anchor match by shape ratio: keep pairs whose w/h ratio to the
                # anchor is within hyp['anchor_t'] (default 4) in both directions;
                # wh is regressed as (sigmoid(wh) * 2) ** 2 * anchor, i.e. 0-4x.
                r = t[:, :, 4:6] / anchors[:, None]
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']
                t = t[j]
                gxy = t[:, 2:4]  # centre measured from the image top-left
                gxi = gain[[2, 3]] - gxy  # centre measured from the bottom-right
                # j,k (l,m): centre lies in the left/top (right/bottom) half of its
                # cell and is not on the border; such targets are duplicated into
                # the neighbouring cell via the offsets below (xy range -0.5..1.5).
                j, k = ((gxy % 1 < g) & (gxy > 1)).T
                l, m = ((gxi % 1 < g) & (gxi > 1)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0
            # b = image index, c = class, gij = owning cell, a = anchor index
            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T
            a = t[:, 6].long()
            tcls.append(c)
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # xy offset in cell + wh
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
            anch.append(anchors[a])
        return tcls, tbox, indices, anch

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
# YOLOX-style (SimOTA) loss function
class ComputeLossOTA:
    """YOLOv7/YOLOX-style loss with SimOTA dynamic label assignment."""

    def __init__(self, model, autobalance=False, version=2):
        """
        Args:
            model: detection model exposing ``.hyp`` and a Detect head.
            autobalance: adapt the per-layer objectness balance online.
            version: 1 -> head is ``model.model[-1]``; 2 -> ``model.detection``.

        Raises:
            ValueError: if ``version`` is not 1 or 2.
        """
        device = next(model.parameters()).device
        h = model.hyp
        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
        # label-smoothing targets: cp for positives, cn for negatives
        self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))
        g = h['fl_gamma']
        if g > 0:
            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
        if version == 1:
            det = de_parallel(model).model[-1]
        elif version == 2:
            det = de_parallel(model).detection
        else:
            # was a silent `pass`, which left `det` unbound and crashed below
            raise ValueError(f'unsupported version: {version}')
        # per-layer objectness weights (3 layers P3-P5, else 5 layers P3-P7)
        self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])
        self.ssi = list(det.stride).index(16) if autobalance else 0  # stride-16 layer index
        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
        for k in 'na', 'nc', 'nl', 'anchors', 'stride':
            setattr(self, k, getattr(det, k))

    def __call__(self, p, targets, imgs):
        """Compute the total loss on SimOTA-assigned targets.

        Args:
            p: list of per-layer predictions, p[i] shape (bs, na, h, w, nc + 5).
            targets: (nt, 6) = (image_idx, class, x, y, w, h), xywh normalised.
            imgs: input image batch; build_targets uses it for absolute coords.

        Returns:
            (total_loss * batch_size, detached tensor [lbox, lobj, lcls]).
        """
        device = targets.device
        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
        bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
        pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p]
        for i, pi in enumerate(p):  # layer index, layer predictions
            b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]  # image, anchor, grid-y, grid-x
            tobj = torch.zeros_like(pi[..., 0], device=device)  # objectness targets
            n = b.shape[0]  # number of positives on this layer
            if n:
                ps = pi[b, a, gj, gi]  # predictions at the positive cells
                grid = torch.stack([gi, gj], dim=1)
                # box decode: xy in (-0.5, 1.5) around the cell, wh in (0, 4) * anchor
                pxy = ps[:, :2].sigmoid() * 2. - 0.5
                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
                pbox = torch.cat((pxy, pwh), 1)
                # matched targets back to this layer's grid units, xy relative to cell
                selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
                selected_tbox[:, :2] -= grid
                iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)
                lbox += (1.0 - iou).mean()  # CIoU loss
                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)
                selected_tcls = targets[i][:, 1].long()
                if self.nc > 1:  # classification loss only when multi-class
                    t = torch.full_like(ps[:, 5:], self.cn, device=device)
                    t[range(n), selected_tcls] = self.cp
                    lcls += self.BCEcls(ps[:, 5:], t)
            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # layer-weighted objectness loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
        if self.autobalance:
            self.balance = [x / self.balance[self.ssi] for x in self.balance]
        lbox *= self.hyp['box']
        lobj *= self.hyp['obj']
        lcls *= self.hyp['cls']
        bs = tobj.shape[0]  # batch size (from the last layer's objectness map)
        return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()

    def build_targets(self, p, targets, imgs):
        """SimOTA dynamic assignment.

        Candidate positives come from ``find_3_positive`` (same grid/anchor
        scheme as YOLOv5); a per-image cost matrix (class BCE + 3 * IoU loss)
        then keeps a dynamic top-k per ground-truth box, and candidates
        claimed by several ground truths go to the cheapest one.

        Args:
            p: list of per-layer predictions, p[i] shape (bs, na, h, w, nc + 5).
            targets: (nt, 6) = (image_idx, class, x, y, w, h), xywh normalised.
            imgs: input image batch used to scale targets to pixels.

        Returns:
            Per-layer lists: image indices, anchor indices, grid-y, grid-x,
            matched targets, and matched anchors.
        """
        device = targets.device
        indices, anch = self.find_3_positive(p, targets)
        matching_bs = [[] for pp in p]
        matching_as = [[] for pp in p]
        matching_gjs = [[] for pp in p]
        matching_gis = [[] for pp in p]
        matching_targets = [[] for pp in p]
        matching_anchs = [[] for pp in p]
        nl = len(p)
        for batch_idx in range(p[0].shape[0]):  # assignment is solved per image
            b_idx = targets[:, 0] == batch_idx
            this_target = targets[b_idx]
            if this_target.shape[0] == 0:
                continue
            # NOTE(review): scales x and y by shape[1] only — assumes square inputs; confirm
            txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
            txyxy = xywh2xyxy(txywh)
            pxyxys = []
            p_obj = []
            p_cls = []
            from_which_layer = []
            all_b = []
            all_a = []
            all_gj = []
            all_gi = []
            all_anch = []
            for i, pi in enumerate(p):
                b, a, gj, gi = indices[i]
                idx = (b == batch_idx)  # this image's candidates on layer i
                b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
                all_b.append(b)
                all_a.append(a)
                all_gj.append(gj)
                all_gi.append(gi)
                all_anch.append(anch[i][idx])
                # created on `device` (was implicitly CPU, which failed when the
                # boolean mask used below lived on the GPU)
                from_which_layer.append(torch.ones(size=(len(b),), device=device) * i)
                fg_pred = pi[b, a, gj, gi]
                p_obj.append(fg_pred[:, 4:5])
                p_cls.append(fg_pred[:, 5:])
                grid = torch.stack([gi, gj], dim=1)
                # decode candidate boxes to absolute pixels for IoU vs targets
                pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
                pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
                pxywh = torch.cat([pxy, pwh], dim=-1)
                pxyxy = xywh2xyxy(pxywh)
                pxyxys.append(pxyxy)
            pxyxys = torch.cat(pxyxys, dim=0)
            if pxyxys.shape[0] == 0:
                continue
            p_obj = torch.cat(p_obj, dim=0)
            p_cls = torch.cat(p_cls, dim=0)
            from_which_layer = torch.cat(from_which_layer, dim=0)
            all_b = torch.cat(all_b, dim=0)
            all_a = torch.cat(all_a, dim=0)
            all_gj = torch.cat(all_gj, dim=0)
            all_gi = torch.cat(all_gi, dim=0)
            all_anch = torch.cat(all_anch, dim=0)
            pair_wise_iou = box_iou(txyxy, pxyxys)  # (num_gt, num_candidates)
            pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
            gt_cls_per_image = (
                F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
                .float()
                .unsqueeze(1)
                .repeat(1, pxyxys.shape[0], 1)
            )
            num_gt = this_target.shape[0]
            # joint score = sigmoid(cls) * sigmoid(obj)
            cls_preds_ = (
                p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
                * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            )
            y = cls_preds_.sqrt_()
            # BCE on the inverse-sigmoid (logit) of sqrt(cls * obj)
            pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
                torch.log(y / (1 - y)), gt_cls_per_image, reduction="none").sum(-1)
            del cls_preds_
            cost = (pair_wise_cls_loss + 3.0 * pair_wise_iou_loss)
            matching_matrix = torch.zeros_like(cost)
            # dynamic k per ground truth: sum of its top-10 IoUs, at least 1
            top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
            dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
            for gt_idx in range(num_gt):
                _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
                matching_matrix[gt_idx][pos_idx] = 1.0
            del top_k, dynamic_ks
            # a candidate claimed by several ground truths goes to the cheapest one
            anchor_matching_gt = matching_matrix.sum(0)
            if (anchor_matching_gt > 1).sum() > 0:
                _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
                matching_matrix[:, anchor_matching_gt > 1] *= 0.0
                matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
            fg_mask_inboxes = matching_matrix.sum(0) > 0.0
            matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
            from_which_layer = from_which_layer[fg_mask_inboxes]
            all_b = all_b[fg_mask_inboxes]
            all_a = all_a[fg_mask_inboxes]
            all_gj = all_gj[fg_mask_inboxes]
            all_gi = all_gi[fg_mask_inboxes]
            all_anch = all_anch[fg_mask_inboxes]
            this_target = this_target[matched_gt_inds]
            for i in range(nl):  # scatter the kept candidates back to their layers
                layer_idx = from_which_layer == i
                matching_bs[i].append(all_b[layer_idx])
                matching_as[i].append(all_a[layer_idx])
                matching_gjs[i].append(all_gj[layer_idx])
                matching_gis[i].append(all_gi[layer_idx])
                matching_anchs[i].append(all_anch[layer_idx])
                matching_targets[i].append(this_target[layer_idx])
        for i in range(nl):
            if matching_targets[i] != []:
                matching_bs[i] = torch.cat(matching_bs[i], dim=0)
                matching_as[i] = torch.cat(matching_as[i], dim=0)
                matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
                matching_gis[i] = torch.cat(matching_gis[i], dim=0)
                matching_targets[i] = torch.cat(matching_targets[i], dim=0)
                matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
            else:
                # was hard-coded device='cuda:0', breaking CPU / non-default-GPU runs
                matching_bs[i] = torch.tensor([], device=device, dtype=torch.int64)
                matching_as[i] = torch.tensor([], device=device, dtype=torch.int64)
                matching_gjs[i] = torch.tensor([], device=device, dtype=torch.int64)
                matching_gis[i] = torch.tensor([], device=device, dtype=torch.int64)
                matching_targets[i] = torch.tensor([], device=device, dtype=torch.int64)
                matching_anchs[i] = torch.tensor([], device=device, dtype=torch.int64)
        return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs

    def find_3_positive(self, p, targets):
        """Collect SimOTA candidate positives (YOLOv5-style grid/anchor match).

        Anchors are matched by width/height ratio (within hyp['anchor_t']),
        and each target is assigned to its own cell plus up to two
        neighbouring cells.

        Args:
            p: list of per-layer predictions, p[i] shape (bs, na, h, w, nc + 5).
            targets: (nt, 6) = (image_idx, class, x, y, w, h), xywh normalised.

        Returns:
            indices, anch: per-layer (b, a, gj, gi) tuples and matched anchors.
        """
        na, nt = self.na, targets.shape[0]
        indices, anch = [], []
        gain = torch.ones(7, device=targets.device).long()  # normalised -> grid units
        # replicate each target per anchor and append the anchor index as column 6
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)
        g = 0.5  # neighbour-cell offset threshold
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # x-left, y-top, x-right, y-bottom
                            ], device=targets.device).float() * g
        for i in range(self.nl):
            anchors = self.anchors[i]
            # p[i].shape = (bs, na, h, w, nc+5) -> gain = [1, 1, w, h, w, h, 1]
            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]
            t = targets * gain  # xywh now in grid units
            if nt:
                # anchor match by shape ratio, both directions within hyp['anchor_t']
                r = t[:, :, 4:6] / anchors[:, None]
                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']
                t = t[j]
                gxy = t[:, 2:4]  # centre measured from the image top-left
                gxi = gain[[2, 3]] - gxy  # centre measured from the bottom-right
                # j,k (l,m): centre lies in the left/top (right/bottom) half of its
                # cell and is not on the border; such targets are duplicated into
                # the neighbouring cell via the offsets below (xy range -0.5..1.5)
                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
                j = torch.stack((torch.ones_like(j), j, k, l, m))
                t = t.repeat((5, 1, 1))[j]
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
            else:
                t = targets[0]
                offsets = 0
            # b = image index, c = class (unused here), gij = owning cell, a = anchor
            b, c = t[:, :2].long().T
            gxy = t[:, 2:4]
            gwh = t[:, 4:6]
            gij = (gxy - offsets).long()
            gi, gj = gij.T
            a = t[:, 6].long()
            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))
            anch.append(anchors[a])
        return indices, anch

- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
- 293
- 294
- 295
- 296
- 297
- 298
- 299
- 300
- 301
- 302
- 303
- 304
- 305
- 306
- 307
- 308
- 309
- 310
- 311
- 312
- 313
- 314
- 315
- 316
- 317
- 318
- 319
- 320
- 321
- 322
- 323
- 324
- 325
- 326
- 327
- 328
- 329
- 330
- 331
- 332
- 333
- 334
- 335
- 336
- 337
- 338
- 339
- 340
- 341
- 342
- 343
- 344
- 345
- 346
- 347
- 348
- 349
- 350
- 351
- 352
- 353
- 354
- 355
- 356
- 357
- 358
- 359
- 360
- 361
- 362
- 363
- 364
- 365
- 366
- 367
- 368
- 369
- 370
- 371
- 372
- 373
- 374
- 375
- 376
- 377
- 378
- 379
- 380
- 381
- 382
- 383
- 384
- 385
- 386
- 387
- 388
- 389
- 390
- 391
- 392
- 393
- 394
- 395
- 396
- 397
- 398
- 399
- 400
- 401
- 402
- 403
- 404
- 405
- 406
- 407
- 408
- 409
- 410
- 411
- 412
- 413
- 414
- 415
- 416
- 417
- 418
- 419
- 420
- 421
- 422
- 423
- 424
- 425
- 426
- 427
- 428
- 429
- 430
- 431
- 432
- 433
- 434
- 435
- 436
- 437
- 438
- 439
- 440
- 441
- 442
- 443
- 444
- 445
- 446
- 447
- 448
- 449
- 450
- 451
- 452
- 453
- 454
- 455
- 456
- 457
- 458
- 459
- 460
- 461
- 462
- 463
- 464
- 465
- 466
- 467
- 468
- 469
- 470
- 471
- 472
- 473