• NMS原理及其代码实现


    1. 为什么要用NMS

    YOLOv3在预测阶段, 每个目标至少会生成3个proposals, 但一个目标一般只显示一个proposal, 因此需要对proposals进行去重,
    这里去重的方法是NMS. 而NMS的筛选依据是IoU.

    2. NMS的步骤

    1. 先对所有proposals进行置信度(confidence)的排序, 按照置信度的大小进行降序排序(从大到小排序).

    2. 将最大置信度的proposal ( p r o p o s a l m a x {\rm proposal_{max}} proposalmax)取出来, 与剩下的proposals( p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest)进行IoU的计算, 这里的目的是筛选后面的proposals, 意思是说:

      • 如果 p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest中的某一个proposal与 p r o p o s a l m a x {\rm proposal_{max}} proposalmax之间计算得到的IoU的值小于设定的阈值( t h r e s h {\rm thresh} thresh), 那么就认为这个proposal是可用的(可保留的)

      • 一旦有IoU > t h r e s h {\rm thresh} thresh的proposal, 我们则认为该proposal和 p r o p o s a l m a x {\rm proposal_{max}} proposalmax预测的是同一个目标(object), 因此该框就冗余了(因为它的置信度没有 p r o p o s a l m a x {\rm proposal_{max}} proposalmax的置信度高), 需要去除.

        置信度最大的proposal是一定会保留的, 我们是在挑选剩余的proposals

    3. p r o p o s a l m a x {\rm proposal_{max}} proposalmax与所有 p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest的IoU计算和筛选完毕后, 置信度指针指向下一个(第二高置信度的proposal)

    4. 重复2, 3 -> 递归

    NMS的核心主要事项就是: IoU > 阈值 的proposals需要去除, < 阈值的proposals则保留, 目的是不影响预测其他object的proposals受到影响.

    3. NMS的代码实现

    import torch
    
    
    def iou(proposal, proposals, isMin=False):
        """计算proposals的IoU
            在计算IoU时, 需要求二者交集和并集. 假设两个框的坐标分别为: (x_11, y_11, x_12, y_12)和(x_21, y2_1, x_22, y_22)
                交集框的坐标: (max(x_11, x_21), max(y_11, y_21), min(x_12, x_22), min(y_12, y_22))
    
        Args:
            proposal (_type_): 置信度最高的proposal -> [4]
            proposals (_type_): 剩余的proposals -> [N, 4]
            isMin (bool, optional): IoU的计算模式, 有两种:
                                                        1. (True) 交集 / 最小面积
                                                        2. (False -> Default) 交集 / 并集
        Return:
            IoU (float): 返回proposal与proposals的IoU
        """
        # 计算当前框的面积: proposal = [x, y, w, h]
        box_area = (proposal[2] - proposal[0]) * (proposal[3] - proposal[1])
        
        # 计算proposals中所有框的面积 proposals = [N, [x, y, w, h]]
        boxes_area = (proposals[:, 2] - proposals[:, 0]) * (proposals[:, 3] - proposals[:, 1])
    
        # 计算交集proposal和proposals的计算
        xx_1 = torch.maximum(proposal[0], proposals[:, 0])  # 交集的左上角x坐标
        yy_1 = torch.maximum(proposal[1], proposals[:, 1])  # 交集的左上角y坐标
        xx_2 = torch.minimum(proposal[2], proposals[:, 2])  # 交集的右下角x坐标
        yy_2 = torch.minimum(proposal[3], proposals[:, 3])  # 交集的右下角y坐标
        
        # 特殊情况: 两个框没有挨着 -> 没有交集
        w, h = torch.maximum(torch.Tensor([0]), xx_2 - xx_1), torch.maximum(torch.Tensor([0]), yy_2 - yy_1)
        
        # 获取交集的框的面积
        intersection_area = w * h
        
        if isMin:  # 如果一个框在另一框的内部
            return intersection_area / torch.min(box_area, boxes_area)
        
        else:  # 两个框相交 -> 交集 / 并集
            return intersection_area / (box_area + boxes_area - intersection_area)
        
        
    def nms(proposals, thresh=0.3, isMin=False):
        """非极大值抑制用来去除冗余的proposals
    
        Args:
            proposals (torch.tensor): 网络推理得到的proposals -> [conf, x, y, w, h]
            thresh (float, optional): NMS筛选的阈值. Defaults to 0.3.
            isMin (bool, optional): IoU的计算方式, 默认为交集/并集. Defaults to False.
        """
        
        # 根据proposals的置信度进行降序排序
        sorted_proposals = proposals[proposals[:, 0].argsort(descending=True)]
    
        # 定义一个ls, 用来保存需要保留的proposals
        keep_boxes = []
        
        while len(sorted_proposals) > 0:
            # 取出置信度最高的proposal并存放到ls中
            _box = sorted_proposals[0]  
            keep_boxes.append(_box)
            
            if len(sorted_proposals) > 1:
                # 取出剩余的proposals
                _boxes = sorted_proposals[1:]
                
                # 置信度最高的proposal与其他proposals进行IoU的计算
                """
                    需要注意的是, NMS在筛选的时候是保留IoU小于thresh的. 为什么?
                        两个proposal的IoU越小, 说明两个proposal框起来的对象越不一样, 别忘了, NMS是为了去重, 所以需要保留小于IoU的proposals
    
                    torch.where(条件): 返回符合条件的索引
                """
                sorted_proposals = _boxes[torch.where(iou(_box, _boxes, isMin) < thresh)]
            
            # 当剩下最后一个时候, 就不进行IoU计算了(自己与自己计算IoU没有意义)
            else:
                break
        
        # 将ls转换为高维的tensor
        return torch.stack(keep_boxes)
    
    
    if __name__ == "__main__":
        proposal = torch.tensor(data=[0, 0, 4, 4])
        proposals = torch.tensor(data=[[4, 4, 5, 5],   # 没有交集
                                       [1, 1, 5, 5]])  # 有交集
        
        print(iou(proposal, proposals))  # tensor([0.0000, 0.3913])
    
        
        boxes = torch.tensor(data=[
                                   [0.5, 1, 1, 10, 10],
                                   [0.9, 1, 1, 11, 11],  # 和上面那个很相似
                                   [0.4, 8, 8, 12, 12]  # 和上面两个都不相似
        ])
        
        print(nms(boxes, thresh=0.1))
        """仅保留了2个
        tensor([[ 0.9000,  1.0000,  1.0000, 11.0000, 11.0000],
                [ 0.4000,  8.0000,  8.0000, 12.0000, 12.0000]])
        """
        
        print(nms(boxes, thresh=0.3))
        
        """全部都保留了
        tensor([[ 0.9000,  1.0000,  1.0000, 11.0000, 11.0000],
                [ 0.5000,  1.0000,  1.0000, 10.0000, 10.0000],
                [ 0.4000,  8.0000,  8.0000, 12.0000, 12.0000]])
        """
        ```
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
  • 相关阅读:
    STM32 Cubemx配置SPI编程(使用Flash模块)
    Redis常用命令
    Swift创建单例
    Java项目:SSM个人博客管理系统
    matplotlib show, ion, ioff, clf, pause的作用
    【qml】性能优化 | 常见的界面元素优化
    LeetCode 周赛 340,质数 / 前缀和 / 极大化最小值 / 最短路 / 平衡二叉树
    InVideo AI:用人工智能轻松制作视频
    如何跟踪网络路由链路&检测网络健康状况
    Power BI 傻瓜入门 7. 清理、转换和加载数据
  • 原文地址:https://blog.csdn.net/weixin_44878336/article/details/126163030