• Pytorch训练模型模型转Onnx推理模型及推理测试(通用全流程)


    环境依赖

    Python环境依赖

    CPU版本推理:onnxruntime
    GPU版本推理:onnxruntime-gpu
    torchvision
    PIL
    torch
    netron
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    模型转化

    PyTorch 模型转换成 ONNX 模型时,我们往往只需要轻松地调用一句torch.onnx.export就可以了。这个函数的接口看上去简单,但它在使用上还有着诸多的注意事项。

    前三个必选参数为模型、模型输入、导出的 onnx 文件名,我们对这几个参数已经很熟悉了。我们来着重看一下后面的一些常用可选参数。

    • export_params
      模型中是否存储模型权重。一般中间表示包含两大类信息:模型结构和模型权重,这两类信息可以在同一个文件里存储,也可以分文件存储。ONNX 是用同一个文件表示记录模型的结构和权重的。
      我们部署时一般都默认这个参数为 True。如果 onnx 文件是用来在不同框架间传递模型(比如 PyTorch 到 Tensorflow)而不是用于部署,则可以令这个参数为 False。
    • input_names, output_names
      设置输入和输出张量的名称。如果不设置的话,会自动分配一些简单的名字(如数字)。
      ONNX 模型的每个输入和输出张量都有一个名字。很多推理引擎在运行 ONNX 文件时,都需要以“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据。在进行跟张量有关的设置(比如添加动态维度)时,也需要知道张量的名字。
      在实际的部署流水线中,我们都需要设置输入和输出张量的名称,并保证 ONNX 和推理引擎中使用同一套名称。
    • opset_version
      转换时参考哪个 ONNX 算子集版本,默认为 9。后文会详细介绍 PyTorch 与 ONNX 的算子对应关系。
    • dynamic_axes
      指定输入输出张量的哪些维度是动态的。

    附代码

    import torch
    
    # 这一块区域为模型加载的步骤具体可以依据自己使用情况替换
    from model import efficientnetv2_s as create_model
    device = "cpu"
    model = create_model(num_classes=2).to(device)
    model_weight_path = "./weights1/model-54.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    
    
    
    batch_size = 1  # 批处理大小
    input_shape = (3, 224, 224)  # 输入数据
    
    x = torch.randn(batch_size, *input_shape)  # 生成张量
    export_onnx_file = "test.onnx"  # 目的ONNX文件名
    torch.onnx.export(model,
                      x,
                      export_onnx_file,
                      opset_version=10,
                      do_constant_folding=True,  # 是否执行常量折叠优化
                      input_names=["input"],  # 输入名
                      output_names=["output"],  # 输出名
                      dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                                    "output": {0: "batch_size"}})
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    onnx模型可视化

    import netron
    
    modelData = "./test.onnx"
    netron.start(modelData)
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    使用onnx推理

    import os, sys
    
    sys.path.append(os.getcwd())
    import onnxruntime
    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    
    
    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    
    # 自定义的数据增强
    def get_test_transform(): 
        return transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    
    # 推理的图片路径
    image = Image.open('./0a0b8641cac0ce40315e38af020bb18f-device4-0-f_items6.jpg').convert('RGB')
    
    img = get_test_transform()(image)
    img = img.unsqueeze_(0)  # -> NCHW, 1,3,224,224
    # 模型加载
    onnx_model_path = "test.onnx"
    resnet_session = onnxruntime.InferenceSession(onnx_model_path)
    inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
    outs = resnet_session.run(None, inputs)[0]
    
    print("onnx weights", outs)
    print("onnx prediction", outs.argmax(axis=1)[0])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
  • 相关阅读:
    Matlab 中值滤波原理分析
    【文件读取/包含】任意文件读取漏洞 afr_3
    这几个音乐伴奏提取的方法快码住了
    面试心经
    汽车网络安全 -- ECU会遭受黑客怎样的攻击?
    2023国赛数学建模C题思路代码 - 蔬菜类商品的自动定价与补货决策
    Python 实践
    立哥国家示范项目-5G智慧文旅
    DataFunSummit:2023年数据基础架构峰会-核心PPT资料下载
    开关、电机、断路器、电热偶、电表接线图大全
  • 原文地址:https://blog.csdn.net/xiao_yan_/article/details/133158944