码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • PyTorch实现NMS算法


    PyTorch实现NMS算法

    • 介绍
      • 示例代码

    介绍

    参考链接1:NMS 算法源码实现
    参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
    目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。

    示例代码

    以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes(边界框),scores(每个边界框的得分),和 iou_threshold(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2],其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标。

    import torch
    
    def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
        """
        Perform Non-Maximum Suppression (NMS) on bounding boxes.
    
        Args:
            boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxes
                                  of shape [x1, y1, x2, y2], where N is the number of boxes.
            scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.
            iou_threshold (float): The IoU threshold for suppressing boxes.
    
        Returns:
            torch.Tensor: A tensor of indices of the boxes to keep.
        """
        # Get the areas of the boxes
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1) * (y2 - y1)
    
        # Sort the scores in descending order and get the sorted indices
        _, order = scores.sort(0, descending=True)
    
        keep = []
        while order.numel() > 0:
            if order.numel() == 1:
                i = order.item()
                keep.append(i)
                break
            else:
                i = order[0].item()
                keep.append(i)
    
            # Compute the IoU of the kept box with the rest
            xx1 = torch.max(x1[i], x1[order[1:]])
            yy1 = torch.max(y1[i], y1[order[1:]])
            xx2 = torch.min(x2[i], x2[order[1:]])
            yy2 = torch.min(y2[i], y2[order[1:]])
    
            w = torch.clamp(xx2 - xx1, min=0)
            h = torch.clamp(yy2 - yy1, min=0)
            inter = w * h
            iou = inter / (areas[i] + areas[order[1:]] - inter)
    
            # Keep the boxes with IoU less than the threshold
            inds = torch.where(iou <= iou_threshold)[0]
            order = order[inds + 1]
    
        return torch.tensor(keep, dtype=torch.long)
    

    代码工作原理:

    1. 计算每个边界框的面积。
    2. 根据得分对边界框进行降序排序。
    3. 依次选择得分最高的边界框,并计算它与其他边界框的 IoU。
    4. 保留 IoU 小于阈值的边界框,并继续处理剩余的边界框。
    5. 返回保留的边界框的索引。
  • 相关阅读:
    Vue进阶(六十七)页面刷新路由传参丢失问题分析及解决
    2016款奔驰C200车COMAND显示屏黑屏
    2.4 图解CIO工作指南(IT 架构模型) --- 构成IT 架构的技术要素
    【网页设计】期末大作业html+css(体育网站)--杜丹特篮球介绍8页 带报告
    日常Bug排查-集群逐步失去响应
    Linux上使用ldapsearch命令通过AD GC查询指定用户
    华为 huawei 交换机配置 Dot1q 终结子接口接入 L3VPN 示例
    Java IO流
    F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读
    2023NOIP A层联测19 多边形
  • 原文地址:https://blog.csdn.net/qq_36892712/article/details/139840971
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号