关注 码龄 粉丝数 原力等级 -- 被采纳 被点赞 采纳率 tiaya01 2024-04-18 16:50
采纳率: 85.7%
浏览 1 首页/
编程语言
/ VGG16得到的混淆矩阵错误 python VGG16得到的混淆矩阵错误
这是main代码:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import json
import torch
from torchvision import transforms, datasets
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from model import vgg
class ConfusionMatrix(object):
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
for p, t in zip(preds, labels):
self.matrix[p, t] += 1
def summary(self):
# calculate accuracy
sum_TP = 0
for i in range(self.num_classes):
sum_TP += self.matrix[i, i]
acc = sum_TP / np.sum(self.matrix)
print("the model accuracy is ", acc)
# precision, recall, F1-score, specificity
table = PrettyTable()
table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"] #"Specificity"
for i in range(self.num_classes):
TP = self.matrix[i, i]
FP = np.sum(self.matrix[i, :]) - TP
FN = np.sum(self.matrix[:, i]) - TP
TN = np.sum(self.matrix) - TP - FP - FN
Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
f1_score = round((2*Precision*Recall)/(Precision+Recall),3)
table.add_row([self.labels[i], Precision, Recall, f1_score,Specificity])
print(table)
def plot(self):
matrix = self.matrix
print(matrix)
plt.imshow(matrix, cmap=plt.cm.Blues)
# # 设置x轴坐标label
# plt.xticks(range(self.num_classes), self.labels, rotation=45)
# # 设置y轴坐标label
# plt.yticks(range(self.num_classes), self.labels)
# 设置x轴坐标label为1, 2, 3
plt.xticks(range(self.num_classes), list(range(1, self.num_classes + 1)), rotation=45)
# 设置y轴坐标label为1, 2, 3
plt.yticks(range(self.num_classes), list(range(1, self.num_classes + 1)))
# 显示colorbar
plt.colorbar()
plt.xlabel('True Labels')
plt.ylabel('Predicted Labels')
plt.title('Confusion matrix')
# 在图中标注数量/概率信息
thresh = matrix.max() / 2
for x in range(self.num_classes):
for y in range(self.num_classes):
# 注意这里的matrix[y, x]不是matrix[x, y]
info = int(matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()
plt.show()
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# 数据集路径
data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data"
assert os.path.exists(data_root), "data path {} does not exist.".format(data_root)
validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"),
transform=data_transform)
batch_size = 16
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=2)
net = vgg(model_name="vgg16", num_classes=4)
# load pretrain weights
model_weight_path = "./weights/best_model.pth"
assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
weight_dict = torch.load(model_weight_path, map_location=device)
net.load_state_dict(weight_dict, strict=False)
net.to(device)
# read class_indict
json_label_path = './class_indices.json'
assert os.path.exists(json_label_path), "cannot find {} file".format(json_label_path)
json_file = open(json_label_path, 'r')
class_indict = json.load(json_file)
labels = [label for _, label in class_indict.items()]
confusion = ConfusionMatrix(num_classes=4, labels=labels)
net.eval()
with torch.no_grad():
for val_data in tqdm(validate_loader):
val_images, val_labels = val_data
outputs = net(val_images.to(device))
outputs = torch.softmax(outputs, dim=1)
outputs = torch.argmax(outputs, dim=1)
confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())
confusion.plot()
confusion.summary()
展开全部
收起
写回答
好问题
0 提建议
追加酬金
关注问题
微信扫一扫 点击复制链接 分享 邀请回答
编辑 收藏 删除 结题 收藏 举报 追加酬金 (90%的用户在追加酬金后获得了解决方案) 当前问题酬金 ¥ 0
(可追加 ¥500)
¥ 15¥ 20¥ 50¥ 100¥ 200 支付方式 扫码支付
二维码出错
点击刷新
支付金额
15 元
提供问题酬金的用户不参与问题酬金结算和分配
支付即为同意
《付费问题酬金结算规则》 结题 再想想 删除 再等等
3条回答 默认 最新
-
下午好🌅🌅🌅
本答案参考ChatGPT-3.5
问题:VGG16得到的混淆矩阵错误,如何解决?
解决方案:
- 检查标签的数量是否与预测结果的数量匹配,确保标签数量和预测结果数量是一致的。
- 检查模型是否加载了正确的预训练权重。确保加载了正确的预训练权重文件,且文件路径正确。
- 确保数据预处理的方式与训练时一致。检查数据预处理的步骤是否与训练时使用的相同。例如,确保在测试集上使用了相同的图像缩放、剪裁和归一化方法。
- 检查模型的输出维度是否正确。确保模型的输出维度与类别数量一致。如果模型输出的维度不正确,可以修改模型的最后一层,使其输出维度与类别数量一致。
- 检查混淆矩阵的计算方法是否正确。仔细检查计算混淆矩阵的方法,确保每个预测结果都被正确地累加到混淆矩阵上。
- 检查评估指标计算的代码是否正确。确保计算精确度、召回率和F1-score等评估指标的代码正确无误。
修改后的代码如下所示:
import os
import numpy as np
import torch
from torchvision import transforms, datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from sklearn.metrics import confusion_matrix
from model import vgg
class ConfusionMatrix(object):
def __init__(self, num_classes: int, labels: list):
self.matrix = np.zeros((num_classes, num_classes))
self.num_classes = num_classes
self.labels = labels
def update(self, preds, labels):
self.matrix += confusion_matrix(labels, preds, labels=list(range(self.num_classes)))
def summary(self):
# calculate accuracy
sum_TP = np.sum(np.diagonal(self.matrix))
acc = sum_TP / np.sum(self.matrix)
print("the model accuracy is ", acc)
# precision, recall, F1-score, specificity
table = PrettyTable()
table.field_names = ["label", "Precision", "Recall", "F1-score", "Specificity"]
for i in range(self.num_classes):
TP = self.matrix[i, i]
FP = np.sum(self.matrix[:, i]) - TP
FN = np.sum(self.matrix[i, :]) - TP
TN = np.sum(self.matrix) - TP - FP - FN
Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
f1_score = round((2*Precision*Recall)/(Precision+Recall), 3)
table.add_row([self.labels[i], Precision, Recall, f1_score, Specificity])
print(table)
def plot(self):
matrix = self.matrix
print(matrix)
plt.imshow(matrix, cmap=plt.cm.Blues)
plt.xticks(range(self.num_classes), list(range(1, self.num_classes + 1)), rotation=45)
plt.yticks(range(self.num_classes), list(range(1, self.num_classes + 1)))
plt.colorbar()
plt.xlabel('True Labels')
plt.ylabel('Predicted Labels')
plt.title('Confusion matrix')
thresh = matrix.max() / 2
for x in range(self.num_classes):
for y in range(self.num_classes):
info = int(matrix[y, x])
plt.text(x, y, info,
verticalalignment='center',
horizontalalignment='center',
color="white" if info > thresh else "black")
plt.tight_layout()
plt.show()
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
data_root = r"C:\Users\yingnuo.DESKTOP-9E5CS2I\Desktop\T1-data"
assert os.path.exists(data_root), "data path {} does not exist.".format(data_root)
validate_dataset = datasets.ImageFolder(root=os.path.join(data_root, "test"),
transform=data_transform)
batch_size = 16
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=2)
net = vgg(model_name="vgg16", num_classes=4)
model_weight_path = "./weights/best_model.pth"
assert os.path.exists(model_weight_path), "cannot find {} file".format(model_weight_path)
weight = torch.load(model_weight_path)
net.load_state_dict(weight)
net.to(device)
net.eval()
confusion_matrix_example = ConfusionMatrix(num_classes=4, labels=['class1', 'class2', 'class3', 'class4'])
with torch.no_grad():
for inputs, labels in tqdm(validate_loader):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = net(inputs)
_, preds = torch.max(outputs, 1)
confusion_matrix_example.update(preds.cpu().numpy(), labels.cpu().numpy())
confusion_matrix_example.summary()
confusion_matrix_example.plot()
修改主要包括:
- 使用
sklearn.metrics.confusion_matrix
计算混淆矩阵。 - 修改了计算精确度、召回率、F1-score和特异度的代码。
- 添加了加载预训练权重的代码。
- 将模型转移到CUDA设备上进行推断。
- 添加了用于更新和绘制混淆矩阵的代码。
你可以根据需要修改主函数中的数据路径、类别数量和标签名称等参数。请注意,这只是一个示例代码,你需要根据你的实际需求进行适当修改。另外,你需要确保已安装sklearn
、torchvision
和matplotlib
等库。
本回答被题主选为最佳回答 , 对您是否有帮助呢? 本回答被专家选为最佳回答 , 对您是否有帮助呢? 本回答被题主和专家选为最佳回答 , 对您是否有帮助呢?
解决
无用
评论
打赏
微信扫一扫 点击复制链接
分享 举报
-
【相关推荐】
如果你已经解决了该问题, 非常希望你能够分享一下解决方案, 写成博客, 将相关链接放在评论区, 以帮助更多的人 ^-^
本回答被题主选为最佳回答 , 对您是否有帮助呢? 本回答被专家选为最佳回答 , 对您是否有帮助呢? 本回答被题主和专家选为最佳回答 , 对您是否有帮助呢?
解决
无用
评论
打赏
微信扫一扫 点击复制链接
分享 举报