参考链接1:NMS 算法源码实现
参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。
以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes
(边界框),scores
(每个边界框的得分),和 iou_threshold
(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2]
,其中 (x1, y1)
是左上角坐标,(x2, y2)
是右下角坐标。
import torch
def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
"""
Perform Non-Maximum Suppression (NMS) on bounding boxes.
Args:
boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxes
of shape [x1, y1, x2, y2], where N is the number of boxes.
scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.
iou_threshold (float): The IoU threshold for suppressing boxes.
Returns:
torch.Tensor: A tensor of indices of the boxes to keep.
"""
# Get the areas of the boxes
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
# Sort the scores in descending order and get the sorted indices
_, order = scores.sort(0, descending=True)
keep = []
while order.numel() > 0:
if order.numel() == 1:
i = order.item()
keep.append(i)
break
else:
i = order[0].item()
keep.append(i)
# Compute the IoU of the kept box with the rest
xx1 = torch.max(x1[i], x1[order[1:]])
yy1 = torch.max(y1[i], y1[order[1:]])
xx2 = torch.min(x2[i], x2[order[1:]])
yy2 = torch.min(y2[i], y2[order[1:]])
w = torch.clamp(xx2 - xx1, min=0)
h = torch.clamp(yy2 - yy1, min=0)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter)
# Keep the boxes with IoU less than the threshold
inds = torch.where(iou <= iou_threshold)[0]
order = order[inds + 1]
return torch.tensor(keep, dtype=torch.long)
代码工作原理: