• Generalized Focal Loss v2 原理与代码解析


    paper:Generalized Focal Loss V2: Learning Reliable Localization Quality Estimation for Dense Object Detection

    code:GitHub - implus/GFocalV2: Generalized Focal Loss V2: Learning Reliable Localization Quality Estimation for Dense Object Detection, CVPR2021

    背景

    单阶段目标检测模型中除了分类和回归分支外,还常常用到定位质量估计(Localization Quality Estimcation, LQE)分支,在推理阶段LQE score经常与分类score相乘作为最终得分,因此在LQE的帮助下,高质量的边界框得分往往高于低质量的边界框,大大减小了NMS中高质量框被错误过滤掉的风险。

    之前的模型中的LQE包括YOLO中的Objectness,IoU-Net中的IoU,FCOS中的Centerness,这些方法都有一个共同的特点就是它们都是基于原始的卷积特征,比如点、边界或区域的特征来估计定位质量,如下图(a)-(g)所示。

    文本的创新点

    本文直接利用边界框分布的统计数据来评估定位质量,边界框分布在Generalized Focal Loss v1中提出,它学习每个预测边的离散概率分布,来描述边框回归的不确定性,如下图(a)所示。作者观察到,边框回归的一般分布统计和其真实定位质量有很强的相关性,如下图(b)所示。具体就是,分布的形状(平整度)可以清晰地反应预测结果的定位质量,分布越尖锐,预测结果越准确,反之亦然。因此很自然的就想到,用分布的统计信息来指导定位质量估计的学习,作者提出了一个轻量的子网络Distribution-Guided Quality Predictor(DGQP),利用边框分布统计信息来得到更可靠的LQE score。本文在Generalized Focal Loss v1的基础上,增加了DGQP模块,提出了一种新的dense object detector,Generlized Focal Loss v2,精度进一步得到提升。 

    方法介绍

    上面提到了学习到的边界回归分布的flatness与最终预测框的质量高度相关,一些相关的统计数据可以反映分布的flatness,和GFLv1一样,采用anchor point到gt四边的距离作为回归的target,记左右上下四边分别为 {l,r,t,b}" role="presentation">{l,r,t,b},定义 w" role="presentation">w 边的离散分布为 Pw=[Pw(y0),Pw(y1),...,Pw(yn)],w{l,r,t,b}" role="presentation">Pw=[Pw(y0),Pw(y1),...,Pw(yn)],w{l,r,t,b},作者提出使用每条边分布的Top-k和均值然后拼接起来作为统计特征 FR4(k+1)" role="presentation">FR4(k+1)

    其中Topkm(·)表示Top-k和其均值的联合运算,Concat(·)表示通道拼接,选择Top-k和其均值作为统计输入有两个好处:(1)由于 Pw" role="presentation">Pw 的和是固定的即 i=0nPw(yi)=1" role="presentation">i=0nPw(yi)=1,因此Top-k和其均值反映了分布的平整度:值越大,分布越尖锐,越小,越平整。(2)Top-k和均值可以使统计特征对其在分布域上的相对偏移不敏感,如下图所示,从而可以得到一个不受对象尺度影响的鲁棒表示。

     

    给定一般分布的统计特征 F" role="presentation">F 作为输入,作者设计了一个非常轻量的子网络 F()" role="presentation">F() 来预测最终的IoU质量估计。这个子网络只包含两个全连接层,分别接ReLU和Sigmoid,最终IoU标量I计算公式如下

    其中 δ" role="presentation">δσ" role="presentation">σ 分别表示ReLU和Sigmoid,W1Rp×4(k+1)" role="presentation">W1Rp×4(k+1)W2R1×p" role="presentation">W2R1×pk" role="presentation">k 表示Top-k,p" role="presentation">p 是是隐藏层的维度(文本中分别设置k=4,p=64" role="presentation">k=4,p=64),GFLv2的整体结构如下图所示,其中红色虚线框部分就是DGQP

     

    代码解析

    下面是GFL v1最终的分类和回归head的实现,其中输入x是经backbone和neck后的单层输出特征图

    1. def forward_single(self, x, scale):
    2. """Forward feature of a single scale level.
    3. Args:
    4. x (Tensor): Features of a single scale level.
    5. scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
    6. the bbox prediction.
    7. Returns:
    8. tuple:
    9. cls_score (Tensor): Cls and quality joint scores for a single
    10. scale level the channel number is num_classes.
    11. bbox_pred (Tensor): Box distribution logits for a single scale
    12. level, the channel number is 4*(n+1), n is max value of
    13. integral set.
    14. """
    15. cls_feat = x # (2,256,38,38)
    16. reg_feat = x
    17. for cls_conv in self.cls_convs:
    18. cls_feat = cls_conv(cls_feat)
    19. # (2,256,38,38)
    20. for reg_conv in self.reg_convs:
    21. reg_feat = reg_conv(reg_feat)
    22. # (2,256,38,38)
    23. cls_score = self.gfl_cls(cls_feat) # (2,20,38,38)
    24. bbox_pred = scale(self.gfl_reg(reg_feat)).float() # (2,68,38,38), 68=4x(16+1)
    25. return cls_score, bbox_pred

    下面是GFL v2的实现

    其中对回归head的输出bbox_pred进行softmax后计算topk(k=4)并与其均值拼接得到统计输入stat,然后输入子网络reg_conf,子网络包含两层全连接层,其中self.total_dim=k+1=5,self.reg_channels=64,得到质量估计quality_score再与分类score相乘作为最终的分类得分。

    1. conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
    2. conf_vector += [self.relu]
    3. conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
    4. self.reg_conf = nn.Sequential(*conf_vector)
    5. def forward_single(self, x, scale):
    6. """Forward feature of a single scale level.
    7. Args:
    8. x (Tensor): Features of a single scale level.
    9. scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
    10. the bbox prediction.
    11. Returns:
    12. tuple:
    13. cls_score (Tensor): Cls and quality joint scores for a single
    14. scale level the channel number is num_classes.
    15. bbox_pred (Tensor): Box distribution logits for a single scale
    16. level, the channel number is 4*(n+1), n is max value of
    17. integral set.
    18. """
    19. cls_feat = x
    20. reg_feat = x
    21. for cls_conv in self.cls_convs:
    22. cls_feat = cls_conv(cls_feat)
    23. for reg_conv in self.reg_convs:
    24. reg_feat = reg_conv(reg_feat)
    25. bbox_pred = scale(self.gfl_reg(reg_feat)).float()
    26. N, C, H, W = bbox_pred.size()
    27. prob = F.softmax(bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
    28. prob_topk, _ = prob.topk(self.reg_topk, dim=2)
    29. if self.add_mean:
    30. stat = torch.cat([prob_topk, prob_topk.mean(dim=2, keepdim=True)],
    31. dim=2)
    32. else:
    33. stat = prob_topk
    34. quality_score = self.reg_conf(stat.reshape(N, -1, H, W))
    35. cls_score = self.gfl_cls(cls_feat).sigmoid() * quality_score
    36. return cls_score, bbox_pred

    实验

    消融实验

    Combination of Input Statistics

    除了Top-k之外还有其它可以反应分布特征的统计值,比如均值和方差,作者进行了对比实验,结果如下,可以看出选用Top-k和均值时精度最高。

    Structure of DGQP

    作者对DGQP中参数k和p的选择进行了对比实验,结果如下,可以看出当k=4, p=64时精度最高。

    Type of Input Features 

    和之前基于原始卷积特征不同,本文提出的DGQP是第一个使用边界框分布的统计来生成定位质量估计的方法,作者和之前基于卷积特征的方法(如图(2)中的a-g,包括点、区域、边界等的特征)进行了对比实验,结果如下,可以看出本文提出的基于边界框分布的统计来生成LQE不仅精度最高速度也最快。

    Usage of the Decomposed Form

    本文采用的是decomposed形式,即定位质量估计得分 I" role="presentation">I 与分类得分 C" role="presentation">C 相乘作为最终分类得分的形式,composed的形式时将两者进行拼接然后再经常全连接层得到最终分类得分的形式,如下图所示。

    结果如下表所示,可以看出,decomposed形式精度更高。 

    Compatibility for Dense Detectors

    作者将GFLv2(包括本文提出的DGQP结构,和GFL v1中的一般分布表示)应用到其它dense目标检测模型中,结果如下所示,可以看出在其它常见的检测模型中也都获得了精度提升。

    Comparisions with State-of-the-arts

    GFL v2和其它检测模型精度对比如下表所示

    精度-速度权衡的可视化结果如下图所示

    参考

    大白话 Generalized Focal Loss V2 - 知乎 

  • 相关阅读:
    新唐NUC980使用记录:U-Boot & Linux 编译与烧录(基于SD1位置SD卡)
    Vue-router 路由间参数传递看完让你明明白白!
    【驯服野生verilog-mode全记录】day2 —— 模块的例化
    【网页期末作业】用HTML+CSS做一个漂亮简单的学校官网
    算法训练第五十八天
    C++对象池
    Java项目-网页聊天程序
    Oracle常用函数
    美团面试——后端开发岗
    【Java八股文总结】之Java Web
  • 原文地址:https://blog.csdn.net/ooooocj/article/details/127598340