贡献/特点:
预测得到N(100)个预测框,gt为M个框,通常N>M,那么怎么计算损失呢?
这里呢,就先对这100个预测框和gt框进行一个二分图的匹配,先确定每个gt对应的是哪个预测框,最终再计算M个预测框和M个gt框的总损失。
其实很简单,假设现在有一个矩阵,横坐标就是我们预测的100个预测框,纵坐标就是gt框,再分别计算每个预测框和其他所有gt框的cost,这样就构成了一个cost matrix,再确定把如何把所有gt框分配给对应的预测框,才能使得最终的总cost最小。
这里计算的方法就是很经典的匈牙利算法,通常是调用scipy包中的linear_sum_assignment函数来完成。这个函数的输入就是cost matrix,输出一组行索引和一个对应的列索引,给出最佳分配。
匈牙利算法通常用来解决二分图匹配问题,具体原理可以看这里: 二分图匈牙利算法的理解和代码 和 算法学习笔记(5):匈牙利算法
所以通过以上的步骤,就确定了最终100个预测框中哪些预测框会作为有效预测框,哪些预测框会称为背景。再将有效预测框和gt框计算最终损失(有效预测框个数等于gt框个数)。
损失函数:分类损失+回归损失
分类损失:交叉熵损失,去掉log
回归损失:GIOU Loss + L1 Loss
DETR前向传播流程:
论文原文给出的掉包版代码,mAP好像有40,虽然比源码低了2个点,但是代码很简单,只有40多行,方便我们了解整个detr的网络结构:
import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads, num_encoder_layers, num_decoder_layers):
super().__init__()
# backbone = resnet50 除掉average pool和fc层 只保留conv1 - conv5_x
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
# 1x1卷积降维 2048->256
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# 6层encoder + 6层decoder hidden_dim=256 nheads多头注意力机制 8头 num_encoder_layers=num_decoder_layers=6
self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
# 分类头
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
# 回归头
self.linear_bbox = nn.Linear(hidden_dim, 4)
# 位置编码 encoder输入
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
# query pos编码 decoder输入
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
def forward(self, inputs):
x = self.backbone(inputs) # [1,3,800,1066] -> [1,2048,25,34]
h = self.conv(x) # [1,2048,25,34] -> [1,256,25,34]
H, W = h.shape[-2:] # H=25 W=34
# pos = [850,1,256] self.col_embed = [50,128] self.row_embed[:H]=[50,128]
pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
# encoder输入 decoder输入
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1066)
logits, bboxes = detr(inputs)
print(logits.shape) # torch.Size([100, 1, 92])
print(bboxes.shape) # torch.Size([100, 1, 4])
1、为什么ViT只有Encoder,而DETR要用Encoder+Decoder?(从论文实验部分得出结论)
Encoder:Encoder自注意力主要进行全局建模,学习全局的特征,通过这一步其实已经基本可以把图片中的各个物体尽可能的分开;
Decoder:这个时候再使用Decoder自注意力,再做目标检测和分割任务,模型就可以进一步把物体的边界的极值点区域进行一个更进一步精确的划分,让边缘的识别更加精确;
2、object query有什么用?
object query是用来替换anchor的,通过引入可学习的object query,可以让模型自动的去学习图片当中哪些区域是可能有物体的,最终通过object query可以找到100个这种可能有物体的区域。再后面通过二分图匹配的方式找到100个预测框中有效的预测框,进而计算损失即可。
所以说object query就起到了替换anchor的作用,以可学习的方式找到可能有物体的区域,而不会因为使用anchor而造成大量的冗余框。
b站: DETR 论文精读【论文精读】