• Pytorch模型转ONNX部署


    开始以为会很困难,但是其实非常方便,下边分两步走:1. pytorch模型转onnx;2. 使用onnx进行inference

    0. 准备工作

    0.1 安装onnx

    安装onnx和onnxruntime,onnx貌似是个环境。。倒是没有直接使用,onnxruntime是一个onnx的架构,方便部署使用的

    CPU版本:

    pip install onnx -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com
    pip install onnxruntime -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com

    GPU版本:

    pip install onnx -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com
    pip install onnxruntime-gpu  -i http://pypi.douban.com/simple/  --trusted-host pypi.douban.com

    1. pytorch模型转ONNX

    1. ### 导出onnx模型
    2. torch.onnx.export(self.network, {'input dict': input dict}, 'home3/medcog/pbliu/test_onnx.onnx')
    3. print('output a onnx model!!!!!!')

    坑1:dummy input那里的那个dict:{'input_dict': input_dict},'input_dict'是我network中forward中的参数名字,后边的input_dict是实际的数据,batch size=1。

    坑2:只是为了用的话,export三个参数就够了:网络,虚拟输入(bs=1),保存路径。这时候输入的名字会按照顺序被替换掉"onnx::Cast_*",所以你把输入对回去就可以了,我的数据格式修改如下。(并且onnx只接受numpy格式)

    1. onnx_dict = {}
    2. key_prefix = 'onnx::Cast__{}'
    3. onnx_idx = 1
    4. for idx, (k,v) in enumerate(input_dict.items()):
    5. if k.startswith('input'):
    6. onnx_dict[key_prefix.format(onnx_idx)] = v.numpy()
    7. onnx_idx += 1

    2. 如何用onnx进行inference

    1. import onnxruntime as rt
    2. import numpy as np
    3. # 加载 ONNX 模型
    4. sess = rt.InferenceSession('my_model.onnx', providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
    5. # 准备好数据onnx_dict
    6. # 调用模型进行推理
    7. result = sess.run(None, onnx_dict)

    坑3:这里的sess.run中的None应该类似于tf中希望得到的结果,我这里没有命名,所以就写None了,会默认返回你之前pytorch输出的变量

    坑4:sess.run使用的数据onnx_dict就是'onnx::Cast_*'和np array的键值对儿了,你之前在pytorch中定义的输入格式都不重要了,不管你是dict还是啥。

    坑5. onnxruntime gpu的时候可能会报错,一个可能是cuda版本不适配的问题,直接在虚拟环境中安装对应版本的cuda就可以

    1. conda install cudatoolkit=10.1
    2. # 版本对照参考https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

    一些其他tips:

    1. 实操时候遇到一个极蠢的问题,onnx比pytorch慢很多,后来发现是我把初始化写到运行代码中了,每次测试一个数据都会重新初始化一遍。

  • 相关阅读:
    16. 机器学习 - 决策树
    CSDN每日一练 |『小艺读书』『小Q的鲜榨柠檬汁』『分层遍历二叉树』2023-10-19
    2024年java面试--mysql(3)
    嵌入式软件设计之美-以实际项目应用MVC框架与状态模式(上)
    看完抱你学会Exchanger
    ES 集群常用排查命令
    git--工作区、暂存区、本地仓库、远程仓库
    vue-router(路由)详细
    Redis配置与优化
    Docker的网络模式
  • 原文地址:https://blog.csdn.net/Eric_Evil/article/details/132913536