• 使用Torchmetrics快速进行验证指标的计算


    TorchMetrics可以为我们提供一种简单、干净、高效的方式来处理验证指标。TorchMetrics提供了许多现成的指标实现,如Accuracy, Dice, F1 Score, Recall, MAE等等,几乎最常见的指标都可以在里面找到。torchmetrics目前已经包好了80+任务评价指标。

    TorchMetrics安装也非常简单,只需要PyPI安装最新版本:

     pip install torchmetrics
    
    • 1

    基本流程介绍

    在训练时我们都是使用微批次训练,对于TorchMetrics也是一样的,在一个批次前向传递完成后将目标值Y和预测值Y_PRED传递给torchmetrics的度量对象,度量对象会计算批次指标并保存它(在其内部被称为state)。

    当所有的批次完成时(也就是训练的一个Epoch完成),我们就可以从度量对象返回最终结果(这是对所有批计算的结果)。这里的每个度量对象都是从metric类继承,它包含了4个关键方法:

    • metric.forward(pred,target) - 更新度量状态并返回当前批次上计算的度量结果。如果您愿意,也可以使用metric(pred, target),没有区别。
    • metric.update(pred,target) - 与forward相同,但是不会返回计算结果,相当于是只将结果存入了state。如果不需要在当前批处理上计算出的度量结果,则优先使用这个方法,因为他不计算最终结果速度会很快。
    • metric.compute() - 返回在所有批次上计算的最终结果。也就是说其实forward相当于是update+compute。
    • metric.reset() - 重置状态,以便为下一个验证阶段做好准备。

    也就是说:在我们训练的当前批次,获得了模型的输出后可以forward或update(建议使用update)。在批次完成后,调用compute以获取最终结果。最后,在验证轮次(Epoch)或者启用新的轮次进行训练时您调用reset重置状态指标

    例如下面的代码:

     import torch
     import torchmetrics
     
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = YourModel().to(device)
     metric = torchmetrics.Accuracy()
     
     for batch_idx, (data, target) in enumerate(val_dataloader):
         data, target = data.to(device), target.to(device)
         output = model(data)
         # metric on current batch
         batch_acc = metric.update(preds, target)
         print(f"Accuracy on batch {i}: {batch_acc}")
     
     # metric on all batches using custom accumulation
     val_acc = metric.compute()
     print(f"Accuracy on all data: {val_acc}")
     
     # Resetting internal state such that metric is ready for new data
     metric.reset()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    MetricCollection

    在上面的示例中,使用了单个指标进行计算,但一般情况下可能会包含多个指标。Torchmetrics提供了MetricCollection可以将多个指标包装成单个可调用类,其接口与上面的基本用法相同。这样我们就无需单独处理每个指标。

    代码如下:

     import torch
     from torchmetrics import MetricCollection, Accuracy, Precision, Recall
     
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     model = YourModel().to(device)
     # collection of all validation metrics
     metric_collection = MetricCollection({
         'acc': Accuracy(),
         'prec': Precision(num_classes=10, average='macro'),
         'rec': Recall(num_classes=10, average='macro')
     })
     
     for batch_idx, (data, target) in enumerate(val_dataloader):
         data, target = data.to(device), target.to(device)
         output = model(data)
         batch_metrics = metric_collection.forward(preds, target)
         print(f"Metrics on batch {i}: {batch_metrics}")
     
     val_metrics = metric_collection.compute()
     print(f"Metrics on all data: {val_metrics}")
     metric.reset()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    也可以使用列表而不是字典,但是使用字典会更加清晰。

    自定义指标

    虽然Torchmetrics包含了很多常见的指标,但是有时我们还需要自己定义一些不常用的特定指标。我们只需要继承 Metric 类并且实现 updatecomputing 方法就可以了,另外就是需要在类初始化的时候使用self.add_state(state_name, default)来初始化我们的对象。

    代码也很简单:

     import torch
     import torchmetrics
     
     class MyAccuracy(Metric):
         def __init__(self, delta):
             super().__init__()
             # to count the correct predictions
             self.add_state('corrects', default=torch.tensor(0))
             # to count the total predictions
             self.add_state('total', default=torch.tensor(0))
     
         def update(self, preds, target):
             # update correct predictions count
             self.correct += torch.sum(preds == target)
             # update total count, numel() returns the total number of elements 
             self.total += target.numel()
     
         def compute(self):
             # final computation
             return self.correct / self.total
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    总结

    就是这样,Torchmetrics为我们指标计算提供了非常简单快速的处理方式,如果你想更多的了解它的用法,请参考官方文档:

    https://avoid.overfit.cn/post/bdedfe4229e04da49049c4e7d56152d1

    作者:Mattia Gatti

  • 相关阅读:
    opencv 使用DNN进行物体分类
    Android kotlin自定义圆形菜单的功能实现
    『无为则无心』Python基础 — 62、Python中自定义迭代器
    【技巧】Windows 下安装 ES 报错:Permission denied
    AUTOSAR汽车电子嵌入式编程精讲300篇-汽车 CAN FD 总线应用研究(中)
    2023计算机毕业设计SSM最新选题之java企业物资管理系统h3109
    vue3快速入门-生命周期
    window文件夹下python脚本实现批量删除无法预览的图片
    Python基础
    unity自动寻路
  • 原文地址:https://blog.csdn.net/m0_46510245/article/details/126657917