• 多分类中混淆矩阵的TP,TN,FN,FP计算


    关于混淆矩阵,各位可以在这里了解:混淆矩阵细致理解_夏天是冰红茶的博客-CSDN博客

    上一篇中我们了解了混淆矩阵,并且进行了类定义,那么在这一节中我们将要对其进行扩展,在多分类中,如何去计算TP,TN,FN,FP。

    原理推导

    这里以三分类为例,这里来看看TP,TN,FN,FP是怎么分布的。

    类别1的标签:

    类别2的标签:

    类别3的标签:

    这样我们就能知道了混淆矩阵的对角线就是TP

    TP = torch.diag(h)

     假正例(FP)是模型错误地将负类别样本分类为正类别的数量

    FP = torch.sum(h, dim=1) - TP

    假负例(FN)是模型错误地将正类别样本分类为负类别的数量

    FN = torch.sum(h, dim=0) - TP

    最后用总数减去除了 TP 的其他三个元素之和得到 TN

    TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)

    逻辑验证

    这里借用上一篇的例子,假如我们这个混淆矩阵是这样的:

    tensor([[2, 0, 0],
                [0, 1, 1],
                [0, 2, 0]])

    为了方便讲解,这里我们对其进行一个简单的编号,即0—8:

    012
    345
    678

    torch.sum(h, dim=1) 可得 tensor([2., 2., 2.]) , torch.sum(h, dim=0) 可得 tensor([2., 3., 1.]) 。

    •  TP:   tensor([2., 1., 0.]) 
    •  FN:   tensor([0., 1., 2.]) 
    •  TN:   tensor([4., 2., 3.]) 
    •  FP:   tensor([0., 2., 1.])

    我们先来看看TP的构成,对应着矩阵的对角线2,1,0;FP在类别1中占3,6号位,在类别2中占1,7号位,在类别3中占2,5号位,加起来即为0,1,2;TN在类别1中占4,5,7,8号位,在类别2中占边角位,在类别3中占0,1,3,4号位,加起来即为4,2,3;FN在类别1中占1,2号位,在类别2中占3,5号位,在类别3中占6,7号位,加起来即为0,2,1。

    补充类定义

    1. import torch
    2. import numpy as np
    3. class ConfusionMatrix(object):
    4. def __init__(self, num_classes):
    5. self.num_classes = num_classes
    6. self.mat = None
    7. def update(self, t, p):
    8. n = self.num_classes
    9. if self.mat is None:
    10. # 创建混淆矩阵
    11. self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)
    12. with torch.no_grad():
    13. # 寻找GT中为目标的像素索引
    14. k = (t >= 0) & (t < n)
    15. # 统计像素真实类别t[k]被预测成类别p[k]的个数
    16. inds = n * t[k].to(torch.int64) + p[k]
    17. self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
    18. def reset(self):
    19. if self.mat is not None:
    20. self.mat.zero_()
    21. @property
    22. def ravel(self):
    23. """
    24. 计算混淆矩阵的TN, FP, FN, TP
    25. """
    26. h = self.mat.float()
    27. n = self.num_classes
    28. if n == 2:
    29. TP, FN, FP, TN = h.flatten()
    30. return TP, FN, FP, TN
    31. if n > 2:
    32. TP = h.diag()
    33. FN = h.sum(dim=1) - TP
    34. FP = h.sum(dim=0) - TP
    35. TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)
    36. return TP, FN, FP, TN
    37. def compute(self):
    38. """
    39. 主要在eval的时候使用,你可以调用ravel获得TN, FP, FN, TP, 进行其他指标的计算
    40. 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
    41. 计算每个类别的准确率
    42. 计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)
    43. """
    44. h = self.mat.float()
    45. acc_global = torch.diag(h).sum() / h.sum()
    46. acc = torch.diag(h) / h.sum(1)
    47. iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
    48. return acc_global, acc, iu
    49. def __str__(self):
    50. acc_global, acc, iu = self.compute()
    51. return (
    52. 'global correct: {:.1f}\n'
    53. 'average row correct: {}\n'
    54. 'IoU: {}\n'
    55. 'mean IoU: {:.1f}').format(
    56. acc_global.item() * 100,
    57. ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
    58. ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
    59. iu.mean().item() * 100)

    我在代码中添加了属性修饰器,以便我们可以直接的进行调用,并且也考虑到了二分类与多分类不同的情况。

    性能指标

    关于这些指标在网上有很多介绍,这里就不细讲了

    1. class ModelIndex():
    2. def __init__(self,TP, FN, FP, TN, e=1e-5):
    3. self.TN = TN
    4. self.FP = FP
    5. self.FN = FN
    6. self.TP = TP
    7. self.e = e
    8. def Precision(self):
    9. """精确度衡量了正类别预测的准确性"""
    10. return self.TP / (self.TP + self.FP + self.e)
    11. def Recall(self):
    12. """召回率衡量了模型对正类别样本的识别能力"""
    13. return self.TP / (self.TP + self.FN + self.e)
    14. def IOU(self):
    15. """表示模型预测的区域与真实区域之间的重叠程度"""
    16. return self.TP / (self.TP + self.FP + self.FN + self.e)
    17. def F1Score(self):
    18. """F1分数是精确度和召回率的调和平均数"""
    19. p = self.Precision()
    20. r = self.Recall()
    21. return (2*p*r) / (p + r + self.e)
    22. def Specificity(self):
    23. """特异性是指模型在负类别样本中的识别能力"""
    24. return self.TN / (self.TN + self.FP + self.e)
    25. def Accuracy(self):
    26. """准确度是模型正确分类的样本数量与总样本数量之比"""
    27. return (self.TP + self.TN) / (self.TP + self.TN + self.FP + self.FN + self.e)
    28. def FP_rate(self):
    29. """False Positive Rate,假阳率是模型将负类别样本错误分类为正类别的比例"""
    30. return self.FP / (self.FP + self.TN + self.e)
    31. def FN_rate(self):
    32. """False Negative Rate,假阴率是模型将正类别样本错误分类为负类别的比例"""
    33. return self.FN / (self.FN + self.TP + self.e)
    34. def Qualityfactor(self):
    35. """品质因子综合考虑了召回率和特异性"""
    36. r = self.Recall()
    37. s = self.Specificity()
    38. return r+s-1

    参考文章:多分类中TP/TN/FP/FN的计算_Hello_Chan的博客-CSDN博客 

  • 相关阅读:
    Redis常用指令之string、list、set、zset、hash
    Web自动化测试 —— PageObject设计模式!
    安全生产知识竞赛活动小程序界面分享
    2022年7月国产数据库大事记-墨天轮
    家政类小程序开发,互联网+家政系统,全套家政系统开发方案
    Vue学习第22天——Vuex安装使用详解及案例练习(彻底搞懂vuex)
    LeetCode | 循环队列的爱情【恋爱法则——环游世界】
    【路径遍历漏洞】查找、利用、预防
    docker搭建nginx
    tomcat必要的配置
  • 原文地址:https://blog.csdn.net/m0_62919535/article/details/132926719