• NMS (non-maximum suppression)


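    Basic principle

    1. Get the batch size, the number of classes, and the basic candidate-box information
    2. Sort all incoming candidate boxes by confidence
    3. Take the box with the highest confidence as the first accepted detection, then use iou_thres to discard duplicate detections of it
    4. Repeat steps 2 and 3 on the remaining boxes, whose IoU with the accepted box is below iou_thres, until all accepted detections are found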
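
The steps above can be sketched in plain Python. This is a minimal illustration, not the yolov5 implementation (which delegates to torchvision.ops.nms); the helper names `iou` and `greedy_nms` and the sample boxes are assumptions made for the example.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes, scores, iou_thres=0.45):
    """Return the indices of the kept boxes, highest score first."""
    # Step 2: sort candidates by confidence
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep = []
    while order:
        # Step 3: accept the most confident remaining box
        best = order.pop(0)
        keep.append(best)
        # Step 4: only boxes that do not overlap it too much stay in the pool
        order = [k for k in order if iou(boxes[best], boxes[k]) <= iou_thres]
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
kept = greedy_nms(boxes, scores)
print(kept)
```

The second box overlaps the first with an IoU of 81/119, which is above 0.45, so it is treated as a duplicate; the third box does not overlap anything and survives.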

    Schematic

    [Figure: NMS schematic; image not available]

    Code walkthrough

    part.1 Parameters

    (1) prediction
    pred = model(img, augment=False)[0]
    

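    The call chain is: model -> forward -> _forward_once
    Return value of _forward_once: a list of tensors holding three elements, each of shape [bs, anchor_num, grid_w, grid_h, xywh+c+20classes] (this list is what the Detect layer returns in training mode; at inference it first returns the per-scale predictions concatenated into one tensor, which is what [0] selects)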

    def forward(self, x, augment=False, profile=False, visualize=False):
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            # Forward pass through each layer: m.i = index, m.f = from, m.type = class name, m.np = number of params
            # m.f says which layer's output feeds the current layer; in the s model most m.f values are -1
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x
           
    

    If you open yolov5s.yaml, you will see that the only layers with m.f != -1 are the 4 Concat operations and the 1 Detect operation,
    e.g.

    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
    [[17, 20, 23], 1, Detect, [nc, anchors]]
    

    For both the Concat and the Detect operations, the first field (from) is a list

    if m.f != -1:  # if not from previous layer
        x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
    

    As a beginner, it took me a while to understand this snippet; it is actually equivalent to

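    if m.f != -1:
        if isinstance(m.f, int):
            x = y[m.f]  # input is the saved output of one earlier layer
        else:
            inputs = []
            for j in m.f:
                if j == -1:
                    inputs.append(x)  # output of the previous layer
                else:
                    inputs.append(y[j])  # saved output of layer j
            x = inputs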
    

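    Put simply, for a layer whose input does not come (or does not come only) from the previous layer, this fetches the saved outputs of the corresponding source layers so that they can be passed to that module m's forward()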
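
To make the routing concrete, here is a small sketch with toy layers that operate on numbers instead of tensors. `ToyLayer`, `forward_once`, and the four layers are hypothetical stand-ins invented for this example; only the `m.f` / `m.i` / `save` logic mirrors the snippet above.

```python
class ToyLayer:
    """Hypothetical stand-in for a yolov5 module with .i (index) and .f (from)."""

    def __init__(self, i, f, fn):
        self.i, self.f, self.fn = i, f, fn

    def __call__(self, x):
        return self.fn(x)


def forward_once(layers, x, save):
    y = []  # one slot per layer; None when the output is not needed later
    for m in layers:
        if m.f != -1:  # input is not simply the previous layer's output
            if isinstance(m.f, int):
                x = y[m.f]
            else:
                x = [x if j == -1 else y[j] for j in m.f]
        x = m(x)
        y.append(x if m.i in save else None)
    return x


layers = [
    ToyLayer(0, -1, lambda v: v + 1),           # 1 -> 2
    ToyLayer(1, -1, lambda v: v * 10),          # 2 -> 20
    ToyLayer(2, [-1, 0], lambda vs: sum(vs)),   # Concat-like: 20 + 2 = 22
    ToyLayer(3, [2, 0], lambda vs: tuple(vs)),  # Detect-like: reads layers 2 and 0
]
result = forward_once(layers, 1, save={0, 2})
print(result)
```

Only the outputs listed in `save` are kept in `y`, which is why layers 0 and 2 must be saved here: later layers refer to them by index.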

    Structure of prediction

    nc = prediction.shape[2] - 5 # number of classes
    

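    prediction is the direct output of the network model.
    Its shape is (1, 50000, 7): 1 is the number of images, 50000 is the number of candidate boxes predicted by the network, and 7 is a group of values with the following meaning:
    [Figure: layout of the 7 values; image not available. Judging from the code below, indices 0-3 are the box (center x, center y, width, height), index 4 is the objectness confidence, and indices 5-6 are the class scores]

    From this you can see that nc gives the number of classes the network predicts (here 7 - 5 = 2).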

    xc = prediction[..., 4] > conf_thres  # whether the probability that the box contains an object exceeds the threshold
    

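    As the figure above shows, prediction[..., 4] is the 5th value of each box and represents the probability that the box contains an object. Comparing it with conf_thres turns the result into a bool that says whether the box is judged to contain an object. prediction[..., 4] has shape (1, 50000), so xc is also a bool tensor of shape (1, 50000).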
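
The same bookkeeping can be sketched with nested lists in place of tensors. The three sample boxes are made up for illustration, and the list comprehensions only imitate what the tensor indexing does.

```python
conf_thres = 0.25

# One image, three candidate boxes, 7 values per box:
# [x, y, w, h, obj_conf, cls0_score, cls1_score]
prediction = [[
    [50.0, 50.0, 20.0, 20.0, 0.90, 0.8, 0.1],
    [52.0, 51.0, 20.0, 22.0, 0.10, 0.7, 0.2],
    [10.0, 10.0, 5.0, 5.0, 0.60, 0.2, 0.9],
]]

nc = len(prediction[0][0]) - 5  # plays the role of prediction.shape[2] - 5
xc = [[row[4] > conf_thres for row in image] for image in prediction]

for xi, image in enumerate(prediction):
    # Imitates x = x[xc[xi]]: keep only rows whose mask entry is True
    x = [row for row, keep in zip(image, xc[xi]) if keep]
    print(nc, xc[xi], len(x))
```

The mask has one entry per candidate box, so its shape matches the first two dimensions of prediction.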

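    Later, only the boxes for which this mask is True (boxes judged to contain an object) are kept, and the image is skipped if none remain: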

    # If none remain process next image
    if not x.shape[0]:
        continue
    

    Meaning of some variables

    max_wh = 7680  # (pixels) maximum box width and height
    max_nms = 30000  # maximum number of candidate boxes passed to torchvision.ops.nms()
    time_limit = 0.3 + 0.03 * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS
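
In the full function, max_wh is also used as a per-class offset in the "Batched NMS" step (c = x[:, 5:6] * max_wh). The sketch below, with a hypothetical helper `class_offset_boxes` and made-up detections, shows why: shifting each box by cls * max_wh moves different classes so far apart that one NMS call cannot suppress across classes.

```python
def class_offset_boxes(dets, max_wh=7680, agnostic=False):
    """dets: rows of (x1, y1, x2, y2, conf, cls). Returns the shifted boxes."""
    shifted = []
    for x1, y1, x2, y2, conf, cls in dets:
        c = cls * (0 if agnostic else max_wh)  # offset is zero when class-agnostic
        shifted.append((x1 + c, y1 + c, x2 + c, y2 + c))
    return shifted


dets = [
    (0.0, 0.0, 10.0, 10.0, 0.9, 0.0),  # class 0
    (0.0, 0.0, 10.0, 10.0, 0.8, 1.0),  # same location, class 1
]
aware = class_offset_boxes(dets)
agnostic = class_offset_boxes(dets, agnostic=True)
print(aware)
print(agnostic)
```

With the offset, the two identical boxes no longer overlap, so both classes are kept; with agnostic=True they stay on top of each other and the lower-scoring one would be suppressed.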
    

    Defining the output

    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    
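    • prediction.shape[0] (the batch size) sets the length of the list: here the output is a list of length 1 whose element is an empty tensor of shape (0, 6)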
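
One detail of the `[...] * n` idiom is worth a quick sketch: every slot initially refers to the same object. In the example below a plain empty list stands in for torch.zeros((0, 6)), which is an assumption made to keep the example dependency-free. Because the NMS loop rebinds slots (output[xi] = x[i]) instead of modifying the shared object in place, the sharing appears harmless here.

```python
bs = 3
empty = []             # stands in for torch.zeros((0, 6))
output = [empty] * bs  # three references to the same object

same_object = all(entry is empty for entry in output)

output[1] = ["det"]    # rebinding one slot, as output[xi] = x[i] does
untouched = output[0] == [] and output[2] == []

print(same_object, untouched, output)
```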

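    part.2 Full code walkthrough

    Much of this content draws on the article "Learning yolov5: understanding the NMS source code"
    The code below is taken from yolov5 v6.x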

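    import time

    import torch
    import torchvision

    # xywh2xyxy, box_iou and LOGGER are helpers from the yolov5 repository (utils/general.py, utils/metrics.py)


    def non_max_suppression(prediction,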
                            conf_thres=0.25,
                            iou_thres=0.45,
                            classes=None,
                            agnostic=False,
                            multi_label=False,
                            labels=(),
                            max_det=300):
        """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes
    
        Returns:
             list of detections, on (n,6) tensor per image [xyxy, conf, cls]
        """
    
        bs = prediction.shape[0]  # batch size
        nc = prediction.shape[2] - 5  # number of classes
        xc = prediction[..., 4] > conf_thres  # candidates
    
        # Checks
        assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
        assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    
        # Settings
        # min_wh = 2  # (pixels) minimum box width and height
        max_wh = 7680  # (pixels) maximum box width and height
        max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
        time_limit = 0.3 + 0.03 * bs  # seconds to quit after
        redundant = True  # require redundant detections
        multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
        merge = False  # use merge-NMS
    
        t = time.time()
        output = [torch.zeros((0, 6), device=prediction.device)] * bs
        for xi, x in enumerate(prediction):  # image index, image inference
            # Apply constraints
            # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
            x = x[xc[xi]]  # confidence
    
            # Cat apriori labels if autolabelling
            if labels and len(labels[xi]):
                lb = labels[xi]
                v = torch.zeros((len(lb), nc + 5), device=x.device)
                v[:, :4] = lb[:, 1:5]  # box
                v[:, 4] = 1.0  # conf
                v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # cls
                x = torch.cat((x, v), 0)
    
            # If none remain process next image
            if not x.shape[0]:
                continue
    
            # Compute conf
            x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
    
            # Box (center x, center y, width, height) to (x1, y1, x2, y2)
            box = xywh2xyxy(x[:, :4])
    
            # Detections matrix nx6 (xyxy, conf, cls)
            if multi_label:
                i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
                x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
            else:  # best class only
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
    
            # Filter by class
            if classes is not None:
                x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
    
            # Apply finite constraint
            # if not torch.isfinite(x).all():
            #     x = x[torch.isfinite(x).all(1)]
    
            # Check shape
            n = x.shape[0]  # number of boxes
            if not n:  # no boxes
                continue
            elif n > max_nms:  # excess boxes
                x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
    
            # Batched NMS
            c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
            boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
            if i.shape[0] > max_det:  # limit detections
                i = i[:max_det]
            if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
                # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
                iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
                weights = iou * scores[None]  # box weights
                x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
                if redundant:
                    i = i[iou.sum(1) > 1]  # require redundancy
    
            output[xi] = x[i]
            if (time.time() - t) > time_limit:
                LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded')
                break  # time limit exceeded
    
        return output
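
The "Compute conf", "Box" and "best class only" steps can be traced for a single row in plain Python. This is a simplified sketch: `xywh2xyxy_single` and `to_detection` are hypothetical helpers written for this example, not the yolov5 functions, and the sample row is made up.

```python
def xywh2xyxy_single(box):
    """(center x, center y, width, height) -> (x1, y1, x2, y2) for one box."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def to_detection(row):
    """Raw row [x, y, w, h, obj_conf, cls scores...] -> (x1, y1, x2, y2, conf, cls)."""
    obj_conf = row[4]
    cls_conf = [s * obj_conf for s in row[5:]]  # conf = obj_conf * cls_conf
    best = max(range(len(cls_conf)), key=lambda k: cls_conf[k])  # best class only
    return (*xywh2xyxy_single(row[:4]), cls_conf[best], float(best))


row = [50.0, 50.0, 20.0, 10.0, 0.5, 0.2, 0.8]
det = to_detection(row)
print(det)
```

In the real function, rows whose final conf is not above conf_thres are dropped after this step, before the boxes are handed to torchvision.ops.nms.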
    
    

    Details

    If you are chasing performance scores, then for this setting,

    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    

    there is no need to record every class for each box; doing so adds time cost
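
A rough sketch of where the extra cost comes from: with multi_label, one box can turn into several rows (one per class above conf_thres), so more rows reach NMS. The helper `expand` and the sample row are assumptions made for illustration.

```python
def expand(row, conf_thres=0.25, multi_label=False):
    """row: (x1, y1, x2, y2, per-class conf...). Returns (box, conf, cls) rows."""
    box, cls_conf = tuple(row[:4]), row[4:]
    if multi_label:
        # One output row for every class whose confidence passes the threshold
        return [(box, s, j) for j, s in enumerate(cls_conf) if s > conf_thres]
    # Best class only: at most one output row per box
    j = max(range(len(cls_conf)), key=lambda k: cls_conf[k])
    return [(box, cls_conf[j], j)] if cls_conf[j] > conf_thres else []


row = (0.0, 0.0, 10.0, 10.0, 0.6, 0.5, 0.1)
single = expand(row)
multi = expand(row, multi_label=True)
print(len(single), len(multi))
```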

  • Original article: https://blog.csdn.net/weixin_50862344/article/details/126606218