• inference.py篇


    inference.py 篇

    目录:

    • 前言
    • 思考自己需要载入的超参
    • 书写代码
    • 函数手册

    前言

    在该模块中加载训练好的模型,对测试集的image进行推理。

    思考自己需要载入的超参

    该模块的书写,是train的简约版,例如你可能需要设置和train相同的batch_sizedevicedataloader等信息,但是这次你不需要设置epoch等信息,对模型的参数进行优化等。

    书写代码

    书写顺序如下:

    argparse()方法收集需要传递的所有参数,传入main函数中(可选)。

    main函数中思路如下:

    1. 写路径等信息
    2. 书写dataloder。设置transformsdatasetdataloaderbatch_size等参数,因为dataloader中要用到。
    3. 设置其余超参,如device等,这次你必须要加载train中产生的预训练权重。
    4. 对测试集进行推理

    下以AlexNet中的inference.py为例:

    # add path
    import os, sys
    root_path = os.path.dirname(os.path.dirname(__file__))
    project_path = os.path.dirname(__file__)
    sys.path.append(project_path)
    # add module
    from PIL import Image
    from torchvision import transforms
    import matplotlib.pyplot as plt
    import json
    import torch
    import numpy as np
    from model import AlexNet
    
    
    def parse_args():
    	"""get your args"""
    	
    def convert_image(image_path:str = ""):
        """transform png to jpg"""
    
    def main():
        # 路径
        root_path       = os.path.dirname(os.path.dirname(__file__))
        project_path    = os.path.dirname(__file__)
        weight_path     = os.path.join(root_path, "weight", "AlexNet_2.pth")
        image_path      = "/home/yingmuzhi/AlexNet/daisy.jpg"
        # 加载预测图片
        img             = None
        data_transform  = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        img = Image.open(image_path)
        print(np.array(img).shape)
        img = data_transform(img)   # 只接受[height, width, channel=3]的图片, 即RGB的jpg
        img = torch.unsqueeze(img, dim = 0) # 传入网络需要[batch, channel, height, width]
        # 加载json文件
        try:
            json_file = open(project_path + "/class_indices.json","r")
            class_indict = json.load(json_file)
        except Exception as e:
            print(e)
            exit(-1)
        # 测试参数
        net = AlexNet(num_classes=2)
        net.load_state_dict(torch.load(weight_path))
        net.eval()    # 关闭dropout层并且不会梯度回传
        with torch.no_grad():
            # predict class
            output = net(img)
            # print(output.shape)
            output = torch.squeeze(output)
            # print(output.shape)
            predict = torch.softmax(output, dim = 0)
            # print(predict.shape)
            predict_cla = torch.argmax(predict).numpy()
        print(class_indict[str(predict_cla)], predict[predict_cla].item())
    
    if __name__ == "__main__":
    	args = parse_args()
        main(args)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63

    函数手册

  • 相关阅读:
    计算机系统(20)----- 信号量机制
    Docker常用命令
    esp32s3通过mqtt协议连接阿里云(wifi+mqtt+vscode+espidf4.4.4+py3.8.7)
    【数据挖掘】数据挖掘、关联分析、分类预测、决策树、聚类、类神经网络与罗吉斯回归
    LeetCode 212.单词搜索Ⅱ Python题解
    CSRF防范介绍之一
    战略合作|SubQuery 成为章鱼网络浏览器的秘密武器
    RISC-V Reader 笔记(七)RV64,特权架构,未来可选扩展
    使用Visual Studio Code 进行Python编程(三)
    HarmonyOS开发实例:【分布式手写板】
  • 原文地址:https://blog.csdn.net/qq_43369406/article/details/127932629