• DETR Code Walkthrough (10): matcher.py (models), the Hungarian Matching Algorithm


    I. Importing modules

    import torch
    from scipy.optimize import linear_sum_assignment
    from torch import nn
    from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

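    This code imports modules and functions from PyTorch and SciPy, plus two helpers from the project's own utility module.

    1. `import torch`: imports the PyTorch library, used for deep learning tasks.

    2. `from scipy.optimize import linear_sum_assignment`: imports `linear_sum_assignment` from SciPy. It solves the linear sum assignment problem, which is the problem the Hungarian algorithm addresses: given a cost matrix, choose a one-to-one pairing of rows and columns with the minimum total cost.

    3. `from torch import nn`: imports PyTorch's neural network module. `nn` contains the classes and functions for building network layers, including the `nn.Module` base class used below.

    4. `from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou`: imports two bounding-box helpers from the project's `util.box_ops` module. `box_cxcywh_to_xyxy` converts boxes from (center x, center y, width, height) format to (x0, y0, x1, y1) corner format, and `generalized_box_iou` computes the pairwise generalized IoU (GIoU) between two sets of corner-format boxes.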
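To make the two box helpers concrete, here is a minimal pure-Python sketch for a single pair of boxes. The names `cxcywh_to_xyxy` and `giou` are hypothetical stand-ins written for this illustration; the real `util.box_ops` functions operate on batched tensors and are not reproduced here.

```python
def cxcywh_to_xyxy(box):
    """Convert (center_x, center_y, width, height) to (x0, y0, x1, y1)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def giou(a, b):
    """Generalized IoU of two corner-format boxes (hypothetical helper)."""
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    # Intersection rectangle (zero if the boxes do not overlap).
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = area_a + area_b - inter
    # Smallest box enclosing both inputs.
    hull = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    # IoU minus the fraction of the enclosing box not covered by the union.
    return inter / union - (hull - union) / hull


pred = cxcywh_to_xyxy((0.5, 0.5, 0.5, 0.5))   # (0.25, 0.25, 0.75, 0.75)
same = giou(pred, pred)                        # 1.0 for identical boxes
far = giou((0.0, 0.0, 0.25, 0.25), (0.75, 0.75, 1.0, 1.0))  # -0.875, disjoint boxes
print(pred, same, far)
```

Unlike plain IoU, GIoU stays informative for boxes that do not overlap at all: it becomes more negative the further apart they are, which is why it is useful as a matching cost.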

    II. The HungarianMatcher module

    class HungarianMatcher(nn.Module):
        """This class computes an assignment between the targets and the predictions of the network

        For efficiency reasons, the targets don't include the no_object. Because of this, in general,
        there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
        while the others are un-matched (and thus treated as non-objects).
        """

        def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
            """Creates the matcher

            Params:
                cost_class: This is the relative weight of the classification error in the matching cost
                cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
                cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            """
            super().__init__()
            self.cost_class = cost_class
            self.cost_bbox = cost_bbox
            self.cost_giou = cost_giou
            assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

        @torch.no_grad()
        def forward(self, outputs, targets):
            """ Performs the matching

            Params:
                outputs: This is a dict that contains at least these entries:
                     "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                     "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

                targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                     "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                               objects in the target) containing the class labels
                     "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

            Returns:
                A list of size batch_size, containing tuples of (index_i, index_j) where:
                    - index_i is the indices of the selected predictions (in order)
                    - index_j is the indices of the corresponding selected targets (in order)
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
            """
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost between boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

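    This code defines a PyTorch module named HungarianMatcher, which computes the assignment (matching) between the network's outputs and the targets.

    The module is used in object detection, where each ground-truth target has to be paired with exactly one of the network's predictions before the loss can be computed.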
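The assignment problem underneath can be illustrated with a toy cost matrix. The sketch below assumes 4 predictions (rows) and 2 targets (columns) with made-up costs; it only shows how `linear_sum_assignment` behaves when there are more predictions than targets.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows are predictions, columns are targets; the values are illustrative costs.
cost = np.array([
    [0.9, 0.2],
    [0.1, 0.8],
    [0.5, 0.5],
    [0.3, 0.7],
])

# Returns one row index per column (here 2 pairs), minimizing the summed cost.
pred_idx, tgt_idx = linear_sum_assignment(cost)
total = cost[pred_idx, tgt_idx].sum()

print(pred_idx, tgt_idx)  # [0 1] [1 0]
print(total)              # roughly 0.3
```

Prediction 0 is paired with target 1 and prediction 1 with target 0. Predictions 2 and 3 do not appear in the result, which corresponds to the "un-matched, treated as non-objects" case described in the class docstring.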

    1. The __init__() function

    class HungarianMatcher(nn.Module):
        """This class computes an assignment between the targets and the predictions of the network

        For efficiency reasons, the targets don't include the no_object. Because of this, in general,
        there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
        while the others are un-matched (and thus treated as non-objects).
        """

        def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
            """Creates the matcher

            Params:
                cost_class: This is the relative weight of the classification error in the matching cost
                cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
                cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
            """
            super().__init__()
            self.cost_class = cost_class
            self.cost_bbox = cost_bbox
            self.cost_giou = cost_giou
            assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"

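    This code defines the HungarianMatcher module and its constructor. In detail:

    1. class HungarianMatcher(nn.Module): defines a Python class that inherits from nn.Module and represents the Hungarian matcher.

    2. Docstring: the class-level comment, which summarizes what the class does. The targets do not include a "no object" entry, so there are usually more predictions than targets, and predictions left without a target are treated as non-objects.

    3. def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): the initializer that creates a matcher instance. It takes three optional parameters:

      • cost_class: the relative weight of the classification error in the matching cost, default 1.
      • cost_bbox: the relative weight of the L1 error of the bounding box coordinates in the matching cost, default 1.
      • cost_giou: the relative weight of the GIoU loss in the matching cost, default 1.
    4. super().__init__(): calls the parent class constructor so that the module is initialized correctly.

    5. self.cost_class, self.cost_bbox, self.cost_giou: store the three arguments as instance attributes for use in forward().

    6. assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0": an assertion that at least one of the three weights is non-zero. If all three are zero, every entry of the cost matrix would be zero and the matching would be meaningless, so an AssertionError is raised.

    The purpose of this class is to find, for an object detection model, the optimal one-to-one matching between the network's predictions and the ground-truth targets, so that the loss can then be computed on matched pairs. The three weights set the relative importance of the classification error, the box coordinate error, and the GIoU term in the matching cost. Each target is matched to exactly one prediction, and any prediction left unmatched is treated as "no object".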
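The effect of the weights can be sketched with two made-up cost matrices that disagree about the best pairing. Everything below (the matrices, the `match` helper, and the weight values) is hypothetical and only illustrates how the weighted sum decides which term wins.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Illustrative costs for 2 predictions x 2 targets.
# The class term prefers pairing 0->0, 1->1; the box term prefers 0->1, 1->0.
cls_cost = np.array([[-0.9, -0.1],
                     [-0.2, -0.8]])
box_cost = np.array([[1.0, 0.1],
                     [0.1, 1.0]])


def match(w_class, w_bbox):
    """Return the target index chosen for each prediction (hypothetical helper)."""
    total = w_class * cls_cost + w_bbox * box_cost
    _, cols = linear_sum_assignment(total)
    return cols.tolist()


equal = match(1.0, 1.0)        # box term dominates
class_heavy = match(5.0, 1.0)  # class term dominates
print(equal, class_heavy)      # [1, 0] [0, 1]
```

With equal weights the box distances decide the pairing; raising the class weight to 5 flips it. This is the knob that the three constructor arguments expose.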

    2. The forward() function

        @torch.no_grad()
        def forward(self, outputs, targets):
            """ Performs the matching

            Params:
                outputs: This is a dict that contains at least these entries:
                     "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                     "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

                targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                     "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                               objects in the target) containing the class labels
                     "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

            Returns:
                A list of size batch_size, containing tuples of (index_i, index_j) where:
                    - index_i is the indices of the selected predictions (in order)
                    - index_j is the indices of the corresponding selected targets (in order)
                For each batch element, it holds:
                    len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
            """
            bs, num_queries = outputs["pred_logits"].shape[:2]

            # We flatten to compute the cost matrices in a batch
            out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
            out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

            # Also concat the target labels and boxes
            tgt_ids = torch.cat([v["labels"] for v in targets])
            tgt_bbox = torch.cat([v["boxes"] for v in targets])

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -out_prob[:, tgt_ids]

            # Compute the L1 cost between boxes
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

            # Compute the giou cost between boxes
            cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

            # Final cost matrix
            C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
            C = C.view(bs, num_queries, -1).cpu()

            sizes = [len(v["boxes"]) for v in targets]
            indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

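    The `forward` method is the core of the `HungarianMatcher` class: it performs the matching between predictions and targets. Line by line:

    1. `@torch.no_grad()`: a decorator that disables gradient tracking inside the method. The matching only produces indices and is not differentiable, so no gradients are needed.

    2. `def forward(self, outputs, targets):`: the forward method that performs the matching. It takes two arguments:
       - `outputs`: a dict containing at least these two entries:
         - "pred_logits": a tensor of shape [batch_size, num_queries, num_classes] with the classification logits.
         - "pred_boxes": a tensor of shape [batch_size, num_queries, 4] with the predicted box coordinates.
       - `targets`: a list of length batch_size in which each target is a dict with these two entries:
         - "labels": a tensor of shape [num_target_boxes] with the class labels of the ground-truth objects.
         - "boxes": a tensor of shape [num_target_boxes, 4] with the ground-truth box coordinates.

    3. `bs, num_queries = outputs["pred_logits"].shape[:2]`: reads the batch size and the number of queries.

    4. `out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)`: merges the batch and query dimensions and applies softmax over the class dimension to get class probabilities. The result has shape [batch_size * num_queries, num_classes].

    5. `out_bbox = outputs["pred_boxes"].flatten(0, 1)`: flattens the predicted boxes in the same way, giving shape [batch_size * num_queries, 4].

    6. `tgt_ids = torch.cat([v["labels"] for v in targets])`: concatenates the class labels of all images into one tensor of shape [total number of target boxes].

    7. `tgt_bbox = torch.cat([v["boxes"] for v in targets])`: concatenates the target boxes of all images into one tensor of shape [total number of target boxes, 4].

    8. `cost_class = -out_prob[:, tgt_ids]`: computes the classification cost. Indexing the probability matrix with `tgt_ids` picks, for every prediction, the probability it assigns to each target's class, giving a matrix of shape [batch_size * num_queries, total number of target boxes]. Unlike the loss, this is not the negative log-likelihood; it approximates it with 1 - p(target class), and the constant 1 is dropped because it does not change the matching.

    9. `cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)`: computes the box coordinate cost, the pairwise L1 distance between every predicted box and every target box.

    10. `cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))`: computes the GIoU (Generalized IoU) cost. Both sets of boxes are first converted to corner format, and the pairwise GIoU is negated so that a higher overlap means a lower cost.

    11. `C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou`: combines the classification, box coordinate, and GIoU costs with their weights into the final cost matrix C.

    12. `C = C.view(bs, num_queries, -1).cpu()`: reshapes C to [batch_size, num_queries, total number of target boxes] and moves it to the CPU, where the SciPy solver runs.

    13. `sizes = [len(v["boxes"]) for v in targets]`: collects the number of target boxes of each image into a list.

    14. `indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]`: `C.split(sizes, -1)` cuts the last dimension into one block per image. In block `i`, only row `c[i]` (shape [num_queries, sizes[i]]) pairs image `i`'s predictions with image `i`'s own targets, and `linear_sum_assignment` finds the minimum-cost one-to-one matching on it.

    15. `return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]`: returns a list with one matching result per image. Each result is a tuple of two int64 tensors: the indices of the selected predictions and the indices of the corresponding targets. The number of matched pairs equals min(num_queries, num_target_boxes).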
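The shape bookkeeping in these steps can be checked on random data. The sketch below assumes a batch of 2 images, 4 queries, and 3 classes, with 2 targets in the first image and 1 in the second. It leaves out the GIoU term and the weights for brevity, so it is a simplified variant of the method, not a replacement for it.

```python
import torch
from scipy.optimize import linear_sum_assignment

torch.manual_seed(0)
bs, num_queries, num_classes = 2, 4, 3

# Random stand-ins for the model outputs and the ground truth.
outputs = {
    "pred_logits": torch.randn(bs, num_queries, num_classes),
    "pred_boxes": torch.rand(bs, num_queries, 4),
}
targets = [
    {"labels": torch.tensor([0, 2]), "boxes": torch.rand(2, 4)},
    {"labels": torch.tensor([1]), "boxes": torch.rand(1, 4)},
]

out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [8, 3]
out_bbox = outputs["pred_boxes"].flatten(0, 1)               # [8, 4]
tgt_ids = torch.cat([t["labels"] for t in targets])          # [3]
tgt_bbox = torch.cat([t["boxes"] for t in targets])          # [3, 4]

cost_class = -out_prob[:, tgt_ids]                # [8, 3]: one column per target
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)  # [8, 3]: pairwise L1 distance

# GIoU term omitted in this sketch; both remaining weights are 1.
C = (cost_class + cost_bbox).view(bs, num_queries, -1)  # [2, 4, 3]

sizes = [len(t["boxes"]) for t in targets]  # [2, 1]
blocks = C.split(sizes, -1)                 # shapes [2, 4, 2] and [2, 4, 1]

# For image i, only row i of block i pairs its predictions with its own targets.
indices = [linear_sum_assignment(c[i].numpy()) for i, c in enumerate(blocks)]
for pred_idx, tgt_idx in indices:
    print(pred_idx, tgt_idx)
```

The first image yields two matched pairs and the second yields one, in line with len(index_i) = len(index_j) = min(num_queries, num_target_boxes).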

    III. The build_matcher() function

    def build_matcher(args):
        return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)

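    The build_matcher function builds a HungarianMatcher instance and configures its cost weights from the arguments it receives. Line by line:

    1. def build_matcher(args): defines a function named build_matcher that takes one parameter, args, holding the configuration for the matcher's cost weights.

    2. return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou): creates and returns a HungarianMatcher instance, reading the weights from attributes of args:

      • cost_class: the weight of the classification cost, taken from args.set_cost_class.
      • cost_bbox: the weight of the box coordinate (L1) cost, taken from args.set_cost_bbox.
      • cost_giou: the weight of the GIoU cost, taken from args.set_cost_giou.

    In this way, build_matcher creates a matcher configured from the given arguments and returns it for later use.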
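A sketch of how `args` might be produced and passed in. `StubMatcher` is a hypothetical stand-in that only stores the weights, so the example runs without the DETR repository, and the default values given to the flags are illustrative assumptions.

```python
import argparse
from dataclasses import dataclass


@dataclass
class StubMatcher:
    """Hypothetical stand-in for HungarianMatcher: it only stores the weights."""
    cost_class: float = 1
    cost_bbox: float = 1
    cost_giou: float = 1


def build_matcher(args):
    # Same shape as the function above, but returning the stub.
    return StubMatcher(cost_class=args.set_cost_class,
                       cost_bbox=args.set_cost_bbox,
                       cost_giou=args.set_cost_giou)


parser = argparse.ArgumentParser()
# Flag names follow the attributes read by build_matcher; defaults are assumed.
parser.add_argument("--set_cost_class", default=1, type=float)
parser.add_argument("--set_cost_bbox", default=5, type=float)
parser.add_argument("--set_cost_giou", default=2, type=float)

# Override one weight on the "command line" and keep the other defaults.
args = parser.parse_args(["--set_cost_bbox", "3"])
matcher = build_matcher(args)
print(matcher)
```

Keeping the weights in `args` means the matching cost can be tuned from the command line without touching the matcher code.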

  • Original article: https://blog.csdn.net/sinat_41942180/article/details/132909544