传统DETR提出的encoder-decoder结构,将transformer运用到了目标检测领域,在我看来属于Resnet相对于Alexnet的里程碑级别,思路很开辟但是细节还欠打磨,我分析一下DETR中的缺点:
the deficits of Transformer attention in handling image feature maps as key elements,Modern object detectors use high-resolution feature maps to better detect small objects. However, high-resolution feature maps would lead to an unacceptable complexity for the self-attention module in the Transformer encoder of DETR, which has a quadratic complexity with the spatial size of input feature maps。究其原因是特征图处理模块少,也没有什么类似FPN这种低维和高维特征融合的手段。针对以上的几个问题,Deformable DETR依次提出如下思路:

self.query_embed = nn.Embedding(num_queries, hidden_dim*2),即(10, 128),10是代码中设置的query_num。值得注意的是128,因为这里的self.query_embed一半是tgt,一半是pos_embeds。
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten),让我们一起进入encoder模块看一看
self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)得到reference_points,shape为 [2, 15060, 4 , 2],得到的是在每一层特征图中的相对位置(0 ~ 1)。处理方法如下:
下面让我们重点看一下网络核心模块MSDeformAttn,对应着self.self_attn()
就是将加了pos_embeds的srcs作为query传入,通过Linear生成sampling_offsets和attention_weights,分别对应着每个query的每个head在每个特征层选取的4个keys和权重,可见这里的weight不是QK后生成的,而是直接Linear得到的。
最后传入MSDeformAttnFunction功能模块进行特征融合,实现细节略,输出memory。
结束了encoder模块,输出了memory。退回到deformable_transformer模块:

可见,就是将10个query_embed做了一下复制、拆分,得到真正的query_embed(decoder中也作为query_pos)和tgt,接着将query_embed传入Linear中得到reference_points,最后都传入Decoder中



最后,让我们回到Deformable_Detr模块,从self.transformer中输出结果如下:
后面根据任务转换输出结果的channels,之后就是基本的匈牙利匹配➕损失计算了,和Detr差不多。有一点值得注意,bbox的pred结果是reference_point + self.bbox_embed(hs[i])[…,:2]。相当于网络输出预测是长、宽和基于reference_point的偏移量!!!
至此我对Deformable DETR源码中全部的流程与细节,进行了深度讲解,希望对大家有所帮助,有不懂的地方或者建议,欢迎大家在下方留言评论。
我是努力在CV泥潭中摸爬滚打的江南咸鱼,我们一起努力,不留遗憾!