• pytorch中meter.ClassErrorMeter()使用方法


    PyTorchNet从TorchNet迁移而来,其中提供了很多有用的工具,例如meter,meter提供了一些轻量级的工具,帮助用户快速计算训练过程中的一些指标。

    AverageValueMeter能够计算所有数的平均值和标准差,同意几个epoch中损失的平均值。

    ClassErrorMeter能够计算每个epoch下的类别错误率,在分类任务经常使用。

    下面我们介绍下如何使用ClassErrorMeter()这个方法计算每个epoch的图像分类准确率,对于这个目的,我们可以通过定义变量,然后不断累加每个批次的数据,然后进行计算,但是现在有一个更好的工具,可以帮助我们实现这个操作。

    首先使用meter.ClassErrorMeter()实例化一个类,该类可以想成内部有一个集合,里面会保存一些数据,并定义一些方法能够对这些数据进行处理来满足我们的要求,说白了就是把我们正常计算指标的代码封装到一个类中。

    error_meter = meter.ClassErrorMeter()
    
    • 1

    我们只需要调用类的add函数不断将数据添加到其中即可,该函数有两个参数,分别是outputtarget,第一个参数是模型的softmax输出结果,第二个参数是对应的标签。

    error_meter.add(output.detach(), labels)
    
    • 1

    然后等一个epoch的所有结果全部填入其中之后,就可以使用error_meter.value获得结果

    error_meter.value()
    
    • 1

    但是注意一个问题,他计算的是错误率,如果想要正确率,那么需要用100减去它即可。

    而且还需要注意一个问题,当我们处理完一个epoch之后,需要清空当前的信息,只需要调用reset()即可。

    下面使用一个示例来说明如何使用:

    for epoch in range(20):
        model.train()
        for data in tqdm(train_loader):
            images, labels = data
            optimizer.zero_grad()
            output = model(images)
            loss = loss_function(output, labels)
            loss.backward()
            optimizer.step()
            
            loss_meter.add(loss.item())
            error_meter.add(output.detach(), labels)
    
         # 打印信息
        print("【EPOCH: 】%s" % str(epoch + 1))
        print("训练集损失为%s" % (str(loss_meter.mean)))
        print("训练集精度为%s" % (str(100 - error_meter.value()[0]) + '%'))
        loss_meter.reset()
        error_meter.reset()
        
        
        model.eval()
        for data in tqdm(val_loader):
            images, labels = data
            optimizer.zero_grad()
            output = model(images)
            loss = loss_function(output, labels)
            loss.backward()
            optimizer.step()
            
            loss_meter.add(loss.item())
            error_meter.add(output.detach(), labels)
    
        print("【EPOCH: 】%s" % str(epoch + 1))
        print("验证集损失为%s" % (str(loss_meter.mean)))
        print("验证集精度为%s" % (str(100 - error_meter.value()[0]) + '%'))
        loss_meter.reset()
        error_meter.reset()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
  • 相关阅读:
    4 | Nikto使用
    C++小程序——“靠谱”的预测器
    QFile和QDataStream二进制文件读写第三集
    mysql基于SSM的自习室管理系统毕业设计源码201524
    vscode下ssh免密登录linux服务器
    复制二叉树
    宝塔FTP提示:553 Can‘t open that file: Permission denied的解决方案
    STM32实战项目:从零打造GPS蓝牙自行车码表,掌握传感器、蓝牙、Flash存储等核心技术
    信贷风控拒绝客户的捞回策略详解
    java计算机毕业设计ssm养老管理系统-敬老院系统
  • 原文地址:https://blog.csdn.net/m0_47256162/article/details/127848579