- import torch
- import matplotlib.pyplot as plt
- import json
- from model import AlexNet
- from PIL import Image
- from torchvision import transforms
- data_transform = transforms.Compose(
- [transforms.Resize((224, 224)), # 将图片重新裁剪
- transforms.ToTensor(), # 转化为tensor
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 标准化数据
- # load image
- img = Image.open("1.jpeg") # 网上随便下载,放到好找的路径下
- plt.imshow(img) # 直接载入图像
- img = data_transform(img) 在预处理过程中吧channel提到前面
- img = torch.unsqueeze(img, dim=0) # 添加batch维度
- # read class_indent
- try:
- # 读取保存在json文件中索引对应的类别名称
- json_file = open('./class_indices,json', 'r')
- class_indict = json.load(json_file) # 将json文件解码成字典格式
- except Exception as e:
- print(e)
- exit(-1)
output = torch.squeeze(model(img)):先将图片通过正向传播得到输出,再把输出的batch压缩
predict = torch.softmax(output, dim=0):通过softmax得到一个概率分布
predict_cla = torch.argmax(predict).numpy():找到概率最大处所对应的索引值
print将类别名称和预测概率输出
- # create model
- model = AlexNet(num_classes=5)
- model_weight_path = "./AlexNet.pth"
- model.load_state_dict(torch.load(model_weight_path)) # 载入网络模型
- model.eval() # 关闭dropout
- with torch.no_grad():
- output = torch.squeeze(model(img))
- predict = torch.softmax(output, dim=0)
- predict_cla = torch.argmax(predict).numpy()
- print(class_indict[str(predict_cla)], predict[predict_cla].item())
- plt.show()
容易把玫瑰识别成郁金香,把蒲公英识别成向日葵,郁金香,向日葵,小雏菊可以很好的识别出来,模型的准确率还是有点低。大家自己尝试测试一下吧哈哈。
PyTorch搭建AlexNet网络合集:
PyTorch搭建AlexNet网络模型-CSDN博客