• DN-DETR源码讲解


    一:创新点

    DN-DETR中的DN指的是denoising,即“去噪”,是一种训练时加快收敛速度的trick。作者将网络拆分为了Denoising Part和Matching Part,只有在训练的时才有Denoising Part,inference时去除。

    DN-DETR的主框架和Conditional DETR、DAB DETR完全类似,对它们还不熟悉的小伙伴可以看Conditional DETRDAB DETR这两篇文章。下面展示一下整体网络图:

    在这里插入图片描述

    二:源码分析

    DAB-DETR主模块

    • def init()

    在这里插入图片描述在这里插入图片描述

    • def forward()

    在这里插入图片描述在这里插入图片描述

    init初始化时生成了[91 + 1, 256]的self.label_enc和[10, 4]的self.refpoint_embed,分别是label词缀表(最后一维其实是初始化tgt)和refpoint的初始化。

    forward老生常谈,唯一的不同就是多了prepare_for_dn处理target数据,和dn post process对输出结果作拆分(将[3 2 30 91] 和 [3 2 30 4]分别拆为 [3 2 10 91]、[3 2 10 4]和[3 2 20 91]、[3 2 20 4],前两个作为真正的output和refpoints,后两个作为去噪后的labels和boxes结果扔进mask_dict中,用来计算去噪损失)。最后返回的是[3 2 10 91]、[3 2 10 4]和mask_dict。

    让我们看一下prepare_for_dn函数源码:

    在这里插入图片描述
    在这里插入图片描述在这里插入图片描述在这里插入图片描述在这里插入图片描述

    该函数的功能是由target中真实label和boxes,生成几组group的噪声target,然后拼接在一起。对于label是随机flip,而boxes则是改变center和w、h。得到[2, 20, 256]的input_label_embed和[2, 20, 4]的boxes。注意!最后还添加了[2 10 256]的tgt和[2 10 4]的refpoint_embed

    还有一个重点,就是attn_mask。作者在论文中提出如下见解:

    Therefore, our attention mask is to make sure the matching part cannot see the denoising part and the denoising groups
    cannot see each other as shown in Fig. 4.
    
    • 1
    • 2

    翻译一下就是在decoder时防止泄题,denoising part中各个group之间不能互相看到,matching part中的query不能看到denoising part中的groups。而denoising part中的groups看到matching part也没事,因为它们需要学习,里面不包含“答案”,

    最后输出[2 30 256]的input_query_label,[2 30 4]的input_query_bbox,[30 30]的attn_mask,和包含了大量原始target和索引的字典(其中内容请看源码中的注释,用来最后计算Loss用的)。将它们和src等输入到transformer中,下面看transformer模块:

    Transformer

    在这里插入图片描述中规中矩,encoder和decoder中的细节就不讲了,和DAB-DETR一字不差。最后输出[3 2 30 256]的hs和[3 2 30 4]的references。最后我们再看一下Loss的计算细节:

    Loss计算细节

    • engine.py
      在这里插入图片描述
    • criterion中的 forward(output, target)

    在这里插入图片描述在这里插入图片描述在这里插入图片描述

    Loss计算和DETR常规计算一样,只多了dn loss computation,计算去噪损失,下面是 dn_losses = compute_dn_loss(mask_dict, self.training, aux_num, self.focal_alpha)实现源码:

    • compute_dn_loss

    在这里插入图片描述
    在这里插入图片描述

    通过prepare_for_loss对mask_dict进行处理,将[3 2 20 256]的output_known_coord和[3 2 20 4]的output_known_class中多余的zero行去除,提取出group_num✖label_num个真正的去噪target,文中是35(5✖7=35)个,该函数输出是[3 35 4]和[3 35 91]。

    最后只对最后一维的结果计算去噪Loss,敲重点,这里计算[35 91]的label损失用的是focal函数,实现细节略。

    最后展现一下prepare_for_losstgt_loss_labelstgt_loss_boxes

    • prepare_for_loss
      在这里插入图片描述
    • tgt_loss_labels

    在这里插入图片描述

    • tgt_loss_boxes

    在这里插入图片描述


      至此我对DN-DETR源码中全部的流程与细节,进行了深度讲解,希望对大家有所帮助,有不懂的地方或者建议,欢迎大家在下方留言评论。

    我是努力在CV泥潭中摸爬滚打的江南咸鱼,我们一起努力,不留遗憾!

  • 相关阅读:
    基于单片机仓库温湿度监测报警系统仿真设计
    浅谈pytorch中的nn.Sequential(*net[3: 5])是啥意思
    新生儿早产:原因、科普和注意事项
    C#.NET与JAVA互通之DES加密V2024
    c++11 标准模板(STL)string(四)
    Vue基础语法的进阶,事件处理器,自定义组件及组件通信
    业务层、过滤器
    C++ 纯虚类实例化中对于引用成员的使用
    爱心html的网页
    Springboot 引入 Redis 并配置序列化和封装RedisTemplate
  • 原文地址:https://blog.csdn.net/weixin_43702653/article/details/126378645