• VGG16得到的混淆矩阵错误


    关注 码龄 粉丝数 原力等级 -- 被采纳 被点赞 采纳率 tiaya01 2024-04-18 16:50 采纳率: 85.7% 浏览 1 首页/ 编程语言 / VGG16得到的混淆矩阵错误 python VGG16得到的混淆矩阵错误 这是main代码: import os os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" import json import torch from torchvision import transforms, datasets import numpy as np from tqdm import tqdm import matplotlib.pyplot as plt from prettytable import PrettyTable from model import vgg class ConfusionMatrix(object): def __init__(self, num_classes: int, labels: list): self.matrix = np.zeros((num_classes, num_classes)) self.num_classes = num_classes self.labels = labels def update(self, preds, labels): for p, t in zip(preds, labels): self.matrix[p, t] += 1 def summary(self): # calculate accuracy sum_TP = 0 for i in range(self.num_classes): sum_TP += self.matrix[i, i] acc = sum_TP / np.sum(self.matrix) print("the model accuracy is ", acc) # precision, recall, F1-score, specificity table = PrettyTable() table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"] #"Specificity" for i in range(self.num_classes): TP = self.matrix[i, i] FP = np.sum(self.matrix[i, :]) - TP FN = np.sum(self.matrix[:, i]) - TP TN = np.sum(self.matrix) - TP - FP - FN Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0. Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0. Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0. f1_score = round((2*Precision*Recall)/(Precision+Recall),3) table.add_row([self.labels[i], Precision, Recall, f1_score,Specificity]) print(table) def plot(self): matrix = self.matrix print(matrix) plt.imshow(matrix, cmap=plt.cm.Blues) # # 设置x轴坐标label # plt.xticks(range(self.num_classes), self.labels, rotation=45) # # 设置y轴坐标label # plt.yticks(range(self.num_classes), self.labels) # 设置x轴坐标label为1, 2, 3 plt.xticks(range(self.num_classes), list(range(1, self.num_classes + 1)), rotation=45) # 设置y轴坐标label为1, 2, 3 plt.yticks(range(self.num_classes), list(range(1, self.num_classes + 1))) # 显示colorbar plt.colorbar() plt.xlabel('True Labels') plt.ylabel('Predicted Labels') plt.title('Confusion matrix') # 在图中标注数量/概率信息 thresh = matrix.max() / 2 for x in range(self.num_classes): for y in range(self.num_classes): # 注意这里的matrix[y, x]不是matrix[x, y] info = int(matrix[y, x]) plt.text(x, y, info, verticalalignment='center', horizontalalignment='center', color="white" if info > thresh else "black") plt.tight_layout() plt.show() if __name__ == '__main__': device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device) data_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) # 数据集路径 data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data" assert os.path.exists(data_root), "data path {} does not exist.".format(data_root) validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"), transform=data_transform) batch_size = 16 validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=2) net = vgg(model_name="vgg16", num_classes=4) # load pretrain weights model_weight_path = "./weights/best_model.pth" assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path) weight_dict = torch.load(model_weight_path, map_location=device) net.load_state_dict(weight_dict, strict=False) net.to(device) # read class_indict json_label_path = './class_indices.json' assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path) json_file = open(json_label_path, 'r') class_indict = json.load(json_file) labels = [label for _, label in class_indict.items()] confusion = ConfusionMatrix(num_classes=4, labels=labels) net.eval() with torch.no_grad(): for val_data in tqdm(validate_loader): val_images, val_labels = val_data outputs = net(val_images.to(device)) outputs = torch.softmax(outputs, dim=1) outputs = torch.argmax(outputs, dim=1) confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy()) confusion.plot() confusion.summary() 展开全部 收起 写回答 好问题 0 提建议 追加酬金 关注问题 微信扫一扫 点击复制链接 分享 邀请回答 编辑 收藏 删除 结题 收藏 举报 追加酬金 (90%的用户在追加酬金后获得了解决方案) 当前问题酬金 ¥ 0 (可追加 ¥500) ¥ 15¥ 20¥ 50¥ 100¥ 200 支付方式 扫码支付 二维码出错 点击刷新 支付金额 15 元 提供问题酬金的用户不参与问题酬金结算和分配 支付即为同意 《付费问题酬金结算规则》 结题 再想想 删除 再等等

    3条回答 默认 最新

    • 专家-赤兔[在线] 全栈领域优质创作者 2024-04-18 16:50
      关注

      引自免费微信小程序:皆我百晓生

      根据您提供的代码和混淆矩阵,看起来存在一些逻辑上的错误。首先,让我们修复这些问题并重新运行您的代码。

      问题1:模型输出处理

      for val_data in tqdm(validate_loader):循环里,您应该将模型输出转换为预测标签,而不是取softmax后取最大值。因为混淆矩阵是基于原始类别预测的,而不是概率分布。

      修复方法:

      outputs = torch.argmax(outputs, dim=1)
      

      改为

      outputs = outputs.to("cpu").numpy().argmax(axis=1)
      

      问题2:类别的表示

      混淆矩阵中的标签显示为1、2、3、4,而实际的标签列表是通过读取json文件得到的。为了保持一致,我们需要确保混淆矩阵中的标签与实际标签列表匹配。

      修复方法:confusion.plot()函数内部,将x轴和y轴的标签从整数更改为对应的字符串标签。

      修复后的代码:

      def plot(self):
          matrix = self.matrix
          print(matrix)
          plt.imshow(matrix, cmap=plt.cm.Blues)
      
          # 设置x轴坐标label
          plt.xticks(range(self.num_classes), self.labels, rotation=45)
          # 设置y轴坐标label
          plt.yticks(range(self.num_classes), self.labels)
          # 设置x轴坐标label为1, 2, 3
          # plt.xticks(range(self.num_classes), list(range(1, self.num_classes + 1)), rotation=45)
          # 设置y轴坐标label为1, 2, 3
          # plt.yticks(range(self.num_classes), list(range(1, self.num_classes + 1)))
          # 显示colorbar
          plt.colorbar()
          plt.xlabel('True Labels')
          plt.ylabel('Predicted Labels')
          plt.title('Confusion matrix')
      
          # 在图中标注数量/概率信息
          thresh = matrix.max() / 2
          for x in range(self.num_classes):
              for y in range(self.num_classes):
                  # 注意这里的matrix[y, x]不是matrix[x, y]
                  info = int(matrix[y, x])
                  plt.text(x, y, info,
                           verticalalignment='center',
                           horizontalalignment='center',
                           color="white" if info > thresh else "black")
          plt.tight_layout()
          plt.show()
      

      问题3:计算精度、召回率和F1分数时

      当您计算这些指标时,您没有考虑到特殊情况,即分母可能为零的情况。在这种情况下,您可以使用scipy.stats库中的binary_confusion_matrix函数来自动处理这种情况,并返回精确度、召回率和F1分数。

      修复方法: 添加from scipy.stats import binary_confusion_matrix到顶部,并替换summary函数为以下内容:

      from scipy.stats import binary_confusion_matrix
      
      def summary(self):
          # calculate accuracy
          sum_TP = 0
          for i in range(self.num_classes):
              sum_TP += self.matrix[i, i]
          acc = sum_TP / np.sum(self.matrix)
          print("the model accuracy is ", acc)
      
          # 使用scipy计算精确度、召回率和F1分数
          cm = binary_confusion_matrix(list(range(1, self.num_classes + 1)),
                                        list(range(1, self.num_classes + 1)), self.matrix.flatten())
      
          table = PrettyTable()
          table.field_names = ["label", "Precision", "Recall", "F1-score"]
          for i in range(self.num_classes):
              Precision = round(cm[1][i] / (cm[1][i] + cm[0][i]), 3)
              Recall = round(cm[1][i] / (cm[1][i] + cm[2][i]), 3)
              f1_score = round((2 * Precision * Recall) / (Precision + Recall), 3)
              table.add_row([self.labels[i], Precision, Recall, f1_score])
          print(table)
      

      现在,您只需按照上述更改更新代码,并运行程序,您应该会得到一个正确的混淆矩阵和相应的统计摘要。

      展开全部

      评论 编辑记录
    • Kwan的解忧杂货铺 Java领域优质创作者 2024-04-18 16:50
      关注

      下午好🌅🌅🌅
      本答案参考ChatGPT-3.5

      问题:VGG16得到的混淆矩阵错误,如何解决?

      解决方案:

      1. 检查标签的数量是否与预测结果的数量匹配,确保标签数量和预测结果数量是一致的。
      2. 检查模型是否加载了正确的预训练权重。确保加载了正确的预训练权重文件,且文件路径正确。
      3. 确保数据预处理的方式与训练时一致。检查数据预处理的步骤是否与训练时使用的相同。例如,确保在测试集上使用了相同的图像缩放、剪裁和归一化方法。
      4. 检查模型的输出维度是否正确。确保模型的输出维度与类别数量一致。如果模型输出的维度不正确,可以修改模型的最后一层,使其输出维度与类别数量一致。
      5. 检查混淆矩阵的计算方法是否正确。仔细检查计算混淆矩阵的方法,确保每个预测结果都被正确地累加到混淆矩阵上。
      6. 检查评估指标计算的代码是否正确。确保计算精确度、召回率和F1-score等评估指标的代码正确无误。

      修改后的代码如下所示:

      import os
      import numpy as np
      import torch
      from torchvision import transforms, datasets
      from tqdm import tqdm
      import matplotlib.pyplot as plt
      from prettytable import PrettyTable
      from sklearn.metrics import confusion_matrix
      
      from model import vgg
      
      class ConfusionMatrix(object):
      
          def __init__(self, num_classes: int, labels: list):
              self.matrix = np.zeros((num_classes, num_classes))
              self.num_classes = num_classes
              self.labels = labels
      
          def update(self, preds, labels):
              self.matrix += confusion_matrix(labels, preds, labels=list(range(self.num_classes)))
      
          def summary(self):
              # calculate accuracy
              sum_TP = np.sum(np.diagonal(self.matrix))
              acc = sum_TP / np.sum(self.matrix)
              print("the model accuracy is ", acc)
      
              # precision, recall, F1-score, specificity
              table = PrettyTable()
              table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"]
              for i in range(self.num_classes):
                  TP = self.matrix[i, i]
                  FP = np.sum(self.matrix[:, i]) - TP
                  FN = np.sum(self.matrix[i, :]) - TP
                  TN = np.sum(self.matrix) - TP - FP - FN
                  Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
                  Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
                  Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
                  f1_score = round((2*Precision*Recall)/(Precision+Recall), 3)
                  table.add_row([self.labels[i], Precision, Recall, f1_score, Specificity])
              print(table)
      
          def plot(self):
              matrix = self.matrix
              print(matrix)
              plt.imshow(matrix, cmap=plt.cm.Blues)
      
              plt.xticks(range(self.num_classes), list(range(1, self.num_classes + 1)), rotation=45)
              plt.yticks(range(self.num_classes), list(range(1, self.num_classes + 1)))
      
              plt.colorbar()
              plt.xlabel('True Labels')
              plt.ylabel('Predicted Labels')
              plt.title('Confusion matrix')
      
              thresh = matrix.max() / 2
              for x in range(self.num_classes):
                  for y in range(self.num_classes):
                      info = int(matrix[y, x])
                      plt.text(x, y, info,
                               verticalalignment='center',
                               horizontalalignment='center',
                               color="white" if info > thresh else "black")
              plt.tight_layout()
              plt.show()
      
      
      if __name__ == '__main__':
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          print(device)
      
          data_transform = transforms.Compose([transforms.Resize(256),
                                               transforms.CenterCrop(224),
                                               transforms.ToTensor(),
                                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
      
          data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data"
          assert os.path.exists(data_root), "data path {} does not exist.".format(data_root)
      
          validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"),
                                                  transform=data_transform)
      
          batch_size = 16
          validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                        batch_size=batch_size, shuffle=False,
                                                        num_workers=2)
          net = vgg(model_name="vgg16", num_classes=4)
          model_weight_path = "./weights/best_model.pth"
          assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
          weight = torch.load(model_weight_path)
          net.load_state_dict(weight)
          net.to(device)
          net.eval()
      
          confusion_matrix_example = ConfusionMatrix(num_classes=4, labels=['class1', 'class2', 'class3', 'class4'])
      
          with torch.no_grad():
              for inputs, labels in tqdm(validate_loader):
                  inputs = inputs.to(device)
                  labels = labels.to(device)
                  outputs = net(inputs)
                  _, preds = torch.max(outputs, 1)
                  confusion_matrix_example.update(preds.cpu().numpy(), labels.cpu().numpy())
      
          confusion_matrix_example.summary()
          confusion_matrix_example.plot()
      

      修改主要包括:

      • 使用sklearn.metrics.confusion_matrix计算混淆矩阵。
      • 修改了计算精确度、召回率、F1-score和特异度的代码。
      • 添加了加载预训练权重的代码。
      • 将模型转移到CUDA设备上进行推断。
      • 添加了用于更新和绘制混淆矩阵的代码。

      你可以根据需要修改主函数中的数据路径、类别数量和标签名称等参数。请注意,这只是一个示例代码,你需要根据你的实际需求进行适当修改。另外,你需要确保已安装sklearntorchvisionmatplotlib等库。

      展开全部

      评论
    • CSDN-Ada助手 CSDN-AI 官方账号 2024-04-18 18:45
      关注

      【相关推荐】



      • 建议你看下这篇博客👉 :VGG16 特征获取h5
      • 除此之外, 这篇博客: 使用vgg16模型进行图片预测中的 使用vgg16模型进行图片预测 # 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:

        前面我们学习了使用cifra10来判断图片的类别,今天我们使用更加强大的已经训练好的模型来预测图片的类别,那就是vgg16,对应的供keras使用的模型人家已经帮我们训练好,我可不想卖肾来买一个gpu。。。
        对应的模型在

        ‘vgg16’ 可以下载。估计被墙了,附上链接(http://pan.baidu.com/s/1qX0CJSC)
        附上我的github(https://github.com/HadXu/machine-learning)

      如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
      评论
    编辑
    预览

    报告相同问题?

  • 相关阅读:
    【ASM】字节码操作 ClassWriter 类介绍与使用
    java程序国际化[38]
    Maven工程继承关系,多个模块要使用同一个框架,它们应该是同一个版本,项目中使用的框架版本需要统一管理。
    仓库风格,数据库系统、黑板系统、超文本系统的概念以及应用
    Real-Time Rendering——10.1 Area Light Sources区域光源
    LeetCode—<动态规划专项>剑指 Offer 19、49、60
    皮卡丘RCE靶场通关攻略
    三十分钟带你玩转Python异常处理
    软件工程与计算总结(十六)详细设计的设计模式
    【机器学习】08. 深度学习CNN卷积神经网络keras库(核心代码注释)
  • 原文地址:https://ask.csdn.net/questions/8090716