• Converting a ViT model built in PyTorch to an ONNX model


    This article tries to convert a ViT model built with PyTorch into an ONNX model.

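    First, save the model built in my previous post ("ViT Vision Transformer: a detailed walkthrough of network construction, visualization, data preprocessing, and a full end-to-end example", CSDN blog) as a .pth file: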

    torch.save(model, 'my_vit_model.pth')

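    Then create a new .py file. A separate file is needed because main.py from my previous post imports many torch-related libraries, and running the ONNX export code inside main.py raises a circular import error. For now, the .pth file serves as an intermediate step between the two scripts.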

    In the new .py file, write:

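    import importlib
    # Import torch through importlib (part of the circular-import workaround)
    torch = importlib.import_module('torch')
    # Load the whole pickled model; the ViT class definition must be importable here.
    # weights_only=False is required on PyTorch 2.6+, where torch.load defaults to weights_only=True
    model = torch.load("my_vit_model.pth", weights_only=False)
    model.cpu()
    model.eval()  # export in inference mode (dropout disabled)
    # Create a random input tensor: batch 1, 3 channels, 16x16 (must match the model's input size)
    dummy_input = torch.randn(1, 3, 16, 16)
    torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=18)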

    torch is imported through importlib, again to work around the circular import problem.

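    Running this code now fails with an error saying that ONNX does not support the aten::unflatten operator. There are two ways to fix this. The first is to replace every unflatten call in your PyTorch model with reshape, which ONNX does support (see https://www.cnblogs.com/antelx/p/17564039.html).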
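    A rough sketch of the first option. In a ViT, unflatten is typically used to split the embedding dimension into attention heads; the shapes and variable names below are hypothetical stand-ins, since the previous post's model code is not reproduced here. The point is only that splitting one dimension with unflatten produces exactly the same tensor as the equivalent reshape, so the swap does not change the model's behavior.

```python
import torch

# Hypothetical shapes (assumed for illustration): batch B, tokens N,
# embedding dim D = num_heads * head_dim
B, N, num_heads, head_dim = 2, 5, 4, 8
x = torch.randn(B, N, num_heads * head_dim)

# Original form: split dimension 2 into (num_heads, head_dim) with unflatten
y_unflatten = x.unflatten(2, (num_heads, head_dim))

# ONNX-friendly form: the same split expressed with reshape
y_reshape = x.reshape(B, N, num_heads, head_dim)
# Variant that avoids hard-coding B and N
y_reshape2 = x.reshape(*x.shape[:-1], num_heads, head_dim)

print(y_unflatten.shape)  # torch.Size([2, 5, 4, 8])
print(torch.equal(y_unflatten, y_reshape), torch.equal(y_unflatten, y_reshape2))
```

    The second variant reads the leading dimensions from x.shape, which keeps the code valid if the batch size or token count changes.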

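    The second way is to edit symbolic_opset18.py in the torch.onnx package (located at /home/.local/lib/python3.8/site-packages/torch/onnx here; the exact path depends on your Python installation) so that it reads as follows: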

    """This file exports ONNX ops for opset 18.
    Note [ONNX Operators that are added/updated in opset 18]
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
    New operators:
        CenterCropPad
        Col2Im
        Mish
        OptionalGetElement
        OptionalHasElement
        Pad
        Resize
        ScatterElements
        ScatterND
    """
    import functools
    from typing import Sequence
    import torch
    import torch._C._onnx as _C_onnx
    from torch import _C
    from torch.onnx import (
        _constants,
        _type_utils,
        errors,
        symbolic_helper,
        symbolic_opset11 as opset11,
        symbolic_opset9 as opset9,
        utils,
    )
    from torch.onnx._internal import _beartype, jit_utils, registration
    # EDITING THIS FILE? READ THIS FIRST!
    # see Note [Edit Symbolic Files] in symbolic_helper.py
    __all__ = ["col2im"]
    _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)

    @_onnx_symbolic("aten::col2im")
    @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
    @_beartype.beartype
    def col2im(
        g,
        input: _C.Value,
        output_size: _C.Value,
        kernel_size: _C.Value,
        dilation: Sequence[int],
        padding: Sequence[int],
        stride: Sequence[int],
    ):
        # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
        adjusted_padding = []
        for pad in padding:
            for _ in range(2):
                adjusted_padding.append(pad)
        num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
        if not adjusted_padding:
            adjusted_padding = [0, 0] * num_dimensional_axis
        if not dilation:
            dilation = [1] * num_dimensional_axis
        if not stride:
            stride = [1] * num_dimensional_axis
        return g.op(
            "Col2Im",
            input,
            output_size,
            kernel_size,
            dilations_i=dilation,
            pads_i=adjusted_padding,
            strides_i=stride,
        )

    # Added: symbolic function for aten::unflatten
    @_onnx_symbolic("aten::unflatten")
    def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
        input_dim = symbolic_helper._get_tensor_rank(input)
        if input_dim is None:
            return symbolic_helper._unimplemented(
                "dim",
                "ONNX and PyTorch use different strategies to split the input. "
                "Input rank must be known at export time.",
            )
        # dim could be negative
        input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
        dim = g.op("Add", input_dim, dim)
        dim = g.op("Mod", dim, input_dim)
        input_size = g.op("Shape", input)
        head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
        head_end_idx = g.op(
            "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
        )
        head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)
        dim_plus_one = g.op(
            "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
        )
        tail_start_idx = g.op(
            "Reshape",
            dim_plus_one,
            g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
        )
        tail_end_idx = g.op(
            "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
        )
        tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)
        final_shape = g.op(
            "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
        )
        return symbolic_helper._reshape_helper(g, input, final_shape)

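    This amounts to registering a symbolic function for aten::unflatten in torch.onnx yourself, so that the exporter knows how to translate it into ONNX operators (Shape, Slice, Concat and Reshape).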
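    To make the graph-building code easier to follow, here is a plain-Python sketch of the shape arithmetic the unflatten symbolic encodes with ONNX nodes. The function name unflatten_target_shape is hypothetical and exists only for illustration; the real exporter does these steps inside the ONNX graph, not in Python.

```python
def unflatten_target_shape(input_shape, dim, unflattened_size):
    """Hypothetical pure-Python mirror of the shape computed by the unflatten symbolic."""
    rank = len(input_shape)
    # "dim could be negative": normalize it, as the Add + Mod nodes do
    dim = (rank + dim) % rank
    # Slice(Shape(input), 0, dim): the dimensions before the one being split
    head = list(input_shape[:dim])
    # Slice(Shape(input), dim + 1, INT64_MAX): the dimensions after it
    tail = list(input_shape[dim + 1:])
    # Concat(head, unflattened_size, tail) is the target shape fed to Reshape
    return head + list(unflattened_size) + tail


print(unflatten_target_shape([2, 5, 32], -1, [4, 8]))  # [2, 5, 4, 8]
print(unflatten_target_shape([2, 5, 32], 1, [5, 1]))   # [2, 5, 1, 32]
```

    Because the final step is an ordinary Reshape, a -1 inside unflattened_size would simply be inferred by Reshape, just as with torch.reshape.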

    Then create another .py file and write:

    import onnxruntime as rt
    import numpy as np
    # Load the model (run it on the CPU)
    sess = rt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    # Get the input and output names
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    # Create input data with the same shape as the dummy input used for export
    input_data = np.random.rand(1, 3, 16, 16).astype(np.float32)
    # Run the model
    pred_onnx = sess.run([output_name], {input_name: input_data})
    # Print the prediction
    print(pred_onnx)

    Running this script executes the ONNX model with onnxruntime.
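    The script above only prints the raw output. Assuming the ViT ends in a classification head that outputs unnormalized logits of shape (1, num_classes), a small numpy step can turn that output into a predicted class. The helper logits_to_prediction is hypothetical, and the logits below are made-up stand-ins rather than real model output.

```python
import numpy as np


def logits_to_prediction(logits):
    """Hypothetical helper: (batch, num_classes) raw logits -> class ids and softmax probabilities.

    Assumes the model's last layer outputs unnormalized logits (no softmax inside the model).
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row-wise max before exp() for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs


# sess.run(...) returns a list with one array per requested output;
# this made-up list stands in for pred_onnx
fake_pred_onnx = [np.array([[0.1, 2.0, -1.0]], dtype=np.float32)]
class_ids, probs = logits_to_prediction(fake_pred_onnx[0])
print(class_ids)  # [1]
print(probs)
```

    With the real script, the same call would be logits_to_prediction(pred_onnx[0]).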

  • Original article: https://blog.csdn.net/qq_41816368/article/details/134209626