This post attempts to convert a ViT model built with PyTorch into an ONNX model.
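First, save the model built in my previous post on CSDN ("ViT Vision Transformer: a detailed walkthrough of network construction, visualization, data preprocessing, and a complete worked example") as a .pth file:

```python
# Saves the entire model object (architecture + weights) via pickle
torch.save(model, 'my_vit_model.pth')
```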
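Next, create a new .py file. A separate file is needed because main.py from the previous post imports many torch-related libraries, and running the ONNX conversion code inside main.py fails with a circular import error. The .pth file therefore serves as an intermediate handoff between the two scripts.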
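In the new file, write:

```python
import importlib
torch = importlib.import_module('torch')

# Load the whole pickled model saved above. The module that defines the
# model's class must be importable here, otherwise unpickling fails.
# weights_only=False is needed on PyTorch >= 2.6, where the default became True.
model = torch.load("my_vit_model.pth", weights_only=False)

model.cpu()
# Create a random input tensor: (batch, channels, height, width)
dummy_input = torch.randn(1, 3, 16, 16)
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=18)
```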
Importing torch through importlib in this way is also intended to avoid the circular import problem.
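Running this code at this point fails with an error saying that ONNX does not support the aten::unflatten operator. There are two ways to fix it. The first is to replace every unflatten operation in your own PyTorch model with reshape, which ONNX does support (see https://www.cnblogs.com/antelx/p/17564039.html).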
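A minimal sketch of the first approach, assuming your model calls `Tensor.unflatten` directly in its own code. The helper name `unflatten_via_reshape` is hypothetical and the shapes are illustrative, not taken from the previous post's model. The idea is to compute the target shape explicitly and call `reshape`, so the tracer records a Reshape instead of aten::unflatten. If the unflatten call happens inside a built-in PyTorch module rather than in your own code, this rewrite is not directly available, and the second approach is the one to use.

```python
import torch


def unflatten_via_reshape(x: torch.Tensor, dim: int, sizes) -> torch.Tensor:
    """Equivalent of x.unflatten(dim, sizes), written with reshape only."""
    # unflatten accepts a negative dim; normalize it to a non-negative index
    dim = dim % x.dim()
    # Target shape = dims before `dim` + the new sizes + dims after `dim`
    new_shape = list(x.shape[:dim]) + list(sizes) + list(x.shape[dim + 1:])
    return x.reshape(new_shape)


# Example: split a last axis of size 48 into (3, 16), e.g. a fused q/k/v projection
x = torch.randn(2, 5, 48)
expected = x.unflatten(-1, (3, 16))
result = unflatten_via_reshape(x, -1, (3, 16))
print(result.shape)                   # torch.Size([2, 5, 3, 16])
print(torch.equal(expected, result))  # True
```

Inside the model, a line such as `qkv = proj.unflatten(-1, (3, E))` would become `qkv = unflatten_via_reshape(proj, -1, (3, E))`, or simply a `reshape` with the shape written out.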
The other approach is to edit the symbolic_opset18.py file in PyTorch's ONNX exporter, i.e. the torch/onnx package (/home/.local/lib/python3.8/site-packages/torch/onnx), so that it reads as follows:

```python
"""This file exports ONNX ops for opset 18.
Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
New operators:
    CenterCropPad
    Col2Im
    Mish
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
"""

import functools
from typing import Sequence

import torch
import torch._C._onnx as _C_onnx
from torch import _C
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py

__all__ = ["col2im"]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)


@_onnx_symbolic("aten::col2im")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
@_beartype.beartype
def col2im(
    g,
    input: _C.Value,
    output_size: _C.Value,
    kernel_size: _C.Value,
    dilation: Sequence[int],
    padding: Sequence[int],
    stride: Sequence[int],
):
    # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
    adjusted_padding = []
    for pad in padding:
        for _ in range(2):
            adjusted_padding.append(pad)

    num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
    if not adjusted_padding:
        adjusted_padding = [0, 0] * num_dimensional_axis

    if not dilation:
        dilation = [1] * num_dimensional_axis

    if not stride:
        stride = [1] * num_dimensional_axis

    return g.op(
        "Col2Im",
        input,
        output_size,
        kernel_size,
        dilations_i=dilation,
        pads_i=adjusted_padding,
        strides_i=stride,
    )


# Added: symbolic for aten::unflatten, registered for opset 18
@_onnx_symbolic("aten::unflatten")
def unflatten(g: jit_utils.GraphContext, input, dim, unflattened_size):
    input_dim = symbolic_helper._get_tensor_rank(input)
    if input_dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    # dim could be negative
    input_dim = g.op("Constant", value_t=torch.tensor([input_dim], dtype=torch.int64))
    dim = g.op("Add", input_dim, dim)
    dim = g.op("Mod", dim, input_dim)

    input_size = g.op("Shape", input)

    head_start_idx = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))
    head_end_idx = g.op(
        "Reshape", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    head_part_rank = g.op("Slice", input_size, head_start_idx, head_end_idx)

    dim_plus_one = g.op(
        "Add", dim, g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64))
    )
    tail_start_idx = g.op(
        "Reshape",
        dim_plus_one,
        g.op("Constant", value_t=torch.tensor([1], dtype=torch.int64)),
    )
    tail_end_idx = g.op(
        "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64)
    )
    tail_part_rank = g.op("Slice", input_size, tail_start_idx, tail_end_idx)

    final_shape = g.op(
        "Concat", head_part_rank, unflattened_size, tail_part_rank, axis_i=0
    )

    return symbolic_helper._reshape_helper(g, input, final_shape)
```
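This effectively registers our own symbolic function for aten::unflatten in PyTorch's ONNX exporter, telling it how to express the operator with ONNX ops (Shape, Slice, Concat, and Reshape).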
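The unflatten symbolic above builds the target shape inside the ONNX graph at run time. The same arithmetic can be traced in plain Python to see what it computes. This is an illustrative sketch, not PyTorch code: the helper name `unflatten_target_shape` and the example shapes are hypothetical, and each comment names the ONNX node the symbolic emits for that step.

```python
def unflatten_target_shape(input_shape, dim, unflattened_size):
    """Plain-Python mirror of the shape arithmetic in the unflatten symbolic."""
    rank = len(input_shape)                 # _get_tensor_rank(input)
    dim = (dim + rank) % rank               # Add + Mod: normalize a negative dim
    head = list(input_shape[0:dim])         # Slice(Shape(input), 0, dim)
    tail = list(input_shape[dim + 1:])      # Slice(Shape(input), dim + 1, INT64_MAX)
    # Concat(head, unflattened_size, tail, axis=0); the result feeds Reshape
    return head + list(unflattened_size) + tail


print(unflatten_target_shape([1, 65, 192], -1, [3, 64]))  # [1, 65, 3, 64]
print(unflatten_target_shape([1, 65, 192], 1, [5, 13]))   # [1, 5, 13, 192]
```

This also shows why the symbolic bails out when the input rank is unknown: without `rank`, the negative `dim` cannot be normalized at export time.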
Then create another .py file containing:
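```python
import numpy as np
import onnxruntime as rt

# Load the model on CPU (GPU builds of onnxruntime >= 1.9 require providers to be listed explicitly)
sess = rt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Get the input and output names
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# Create input data with the same shape as the dummy input used for export
input_data = np.random.rand(1, 3, 16, 16).astype(np.float32)

# Run the model
pred_onnx = sess.run([output_name], {input_name: input_data})

# Print the prediction
print(pred_onnx)
```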
Running this file executes the ONNX model with onnxruntime and prints its output.
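`sess.run` returns a list with one NumPy array per requested output. Assuming the ViT from the previous post is an image classifier whose single output holds logits of shape (batch, num_classes) (an assumption not verified here), a small post-processing sketch can turn that list into class predictions. The helper name `top1_from_onnx_output` and the logits values are hypothetical stand-ins.

```python
import numpy as np


def top1_from_onnx_output(pred_onnx):
    """Turn the list returned by sess.run into predicted class indices and probabilities."""
    logits = pred_onnx[0]  # first (and only) requested output, assumed (batch, num_classes)
    # Numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs.argmax(axis=1), probs


# Stand-in for sess.run(...) output: a list holding one (1, 3) logits array (made-up values)
fake_pred = [np.array([[0.1, 2.0, -1.0]], dtype=np.float32)]
classes, probs = top1_from_onnx_output(fake_pred)
print(classes)  # [1]
```

With the real session, `top1_from_onnx_output(pred_onnx)` would be called on the value printed above.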