码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【TensorRT】PyTorch模型转换为ONNX及TensorRT模型


    文章目录

    • 1. PyTorch模型转TensorRT模型流程
    • 2. PyTorch模型转ONNX模型
    • 3. ONNX模型转TensorRT模型
      • 3.1 TensorRT安装
      • 3.2 将ONNX模型转换为TensorRT模型
    • 4. TensorRT在Python中推理
    • 5. 转换TensorRT需要注意的一些语法规则
      • 5.1 tensor索引不支持bool类型作为索引参数
      • 5.2 squeeze()会导致ONNX模型出现 if 节点
      • 5.3 argmax需要至少2维的tensor作为输入
      • 5.4 expand操作转换TensorRT问题
      • 5.5 repeat操作转换TensorRT问题
      • 5.6 dimensions not compatible for scatterND
      • 5.7 数据类型为int32的多维input传入TensorRT与原数据不等
      • 5.8 定位TensorRT的错误fusion,并拆分错误fusion
      • 5.9 TensorRT不支持torch.topk()动态k值

    任务简介:
    TensorRT 模型的推理速度比 libtorch 模型更快,所以 PyTorch 模型转换为 TensorRT 模型部署几乎是最好的选择。通常TensorRT 模型首先需要转换为 ONNX 模型,再由 ONNX 模型转换为TensorRT 模型。本文对转换方法及一些注意点做一个记录。
    在这里插入图片描述


    1. PyTorch模型转TensorRT模型流程

    在这里插入图片描述

    2. PyTorch模型转ONNX模型

    pytorch模型转onnx模型示例:

    torch.onnx.export(model,                          # model being run
                      onnx_inputs,                    # model input (or a tuple for multiple inputs)
                      "trt/dense_tnt.onnx",           # where to save the model (can be a file or file-like object)
                      export_params=True,             # store the trained parameter weights inside the model file
                      verbose=True,                   # if True, all parameters will be exported
                      opset_version=11,               # the ONNX version to export the model to
                      do_constant_folding=False,      # whether to execute constant folding for optimization
                      input_names=['vector_data', 'all_polyline_idx', 'vector_idx', 'invalid_idx', 'map_polyline_idx', 'traj_polyline_idx', 'cent_polyline_idx', 'topk_points_idx'],   # the model's input names
                      output_names=['trajectories'],  # the model's output names
                      dynamic_axes={
       'all_polyline_idx': {
       1: 'all_polyline_num'},
                                    'vector_idx': {
       1: 'vector_num'},
                                    'invalid_idx': {
       1: 'invalid_num'},
                                    'map_polyline_idx': {
       1: 'map_polyline_num'},
                                    'traj_polyline_idx': {
       1: 'traj_polyline_num'},
                                    'cent_polyline_idx': {
       1: 'cent_polyline_num'},
                                    'topk_points_idx': {
       1: 'topk_points_num'},})
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    verbose=True 可以将转换后的模型代码及参数输出,并对应了相应的源代码。
    dynamic_axes 可以设置动态输入,如 'all_polyline_idx': {1: 'all_polyline_num'} 表示 all_polyline_idx 的第1维的shape为动态,且命名为 all_polyline_num。

    ONNX 模型简化:

    onnxsim input_onnx_model_name output_onnx_model_name
    
    • 1

    简化后输出:

    在这里插入图片描述

    3. ONNX模型转TensorRT模型

    3.1 TensorRT安装

    • 点击到官网下载,选择自己需要的版本,需要nvidia账号。

    • 新建文件夹,将压缩文件拷贝进来解压:

    tar xzvf TensorRT-8.4.3.1.Linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz
    
    • 1
    • 解压得到TensorRT-8.4.3.1的文件夹,将里面的lib绝对路径添加到环境变量中:
    export LD_LIBRARY_PATH=${TENSORRT_PATH}/TensorRT-8.4.3.1/lib:$LD_LIBRARY_PATH
    export LIBRARY_PATH=${TENSORRT_PATH}/TensorRT-8.4.3.1/lib:$LIBRARY_PATH
    
    • 1
    • 2
    • 使用pip命令安装TensorRT:
    cd TensorRT-8.4.3.1/python/
    pip install tensorrt-8.4.3.1-cp38-none-linux_x86_64.whl
    
    • 1
    • 2

    3.2 将ONNX模型转换为TensorRT模型

    转换命令:

    ${TENSORRT_PATH}/TensorRT-8.4.2.4/bin/trtexec 
    --onnx=dense_tnt_sim.onnx 
    --minShapes=all_polyline_idx:1x7,vector_idx:1x6,invalid_idx:1x1,map_polyline_idx:1x6,traj_polyline_idx:1x1,cent_polyline_idx:1x3,topk_points_idx:1x50 
    --optShapes=all_polyline_idx:1x1200,vector_idx:1x24000,invalid_idx:1x24000,map_polyline_idx:1x1200,traj_polyline_idx:1x1200,cent_polyline_idx:1x1200,topk_points_idx:1x24000 
    --maxShapes=all_polyline_idx:1x1200,vector_idx:1x24000,invalid_idx:1x24000,map_polyline_idx:1x1200,traj_polyline_idx:1x1200,cent_polyline_idx:1x1200,topk_points_idx:1x24000 
    --saveEngine=dense_tnt_fp32.engine 
    --device=0
    --workspace=48000
    --noTF32
    --verbose
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    --minShapes 为inputs的最小维度;
    --optShapes 为输入常用的inputs维度,我这边输入的是最大维度;
    --maxShapes 为inputs的最大维度;
    --device 设置转换模型使用的gpu;
    --noTF32 不使用tf32数据类型,使用fp32;
    --verbose 输出详细信息。

    4. TensorRT在Python中推理

    整体流程:

    if __name__ == '__main__':
    	device = torch.device("cpu")  # TensorRT模型不管device是cpu还是cuda都会调用gpu,为了方便转换用cpu可以解决device冲突问题
    	trt_path = "/home/chenxin/peanut/DenseTNT/trt/densetnt_vehicle_trt_model_tf32.engine"
    	input_names = ["vector_data", "all_polyline_idx", "vector_idx", "invalid_idx", "map_polyline_idx", "traj_polyline_idx", 
    	"cent_polyline_idx", "topk_points_idx"]
    	output_names = ["1195", "1475", "1785"
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    记一次 .NET 某娱乐聊天流平台 CPU 爆高分析
    Python Flask MongoDB Web开发:前 言
    ElementUI validate 验证结果错误的问题解决过程
    2024年反电诈重点:打击帮信罪&掩隐罪
    pm2工具的介绍
    安装lrzsz
    【Rust日报】2023-10-14 Rust101: 使用光线跟踪渲染Cornell box
    TCP串流场景剖析
    (附源码)ssm高校志愿者服务系统 毕业设计 011648
    从事前端真的没有后端工资高?
  • 原文地址:https://blog.csdn.net/weixin_40633696/article/details/126897389
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号