• torch.onnx.export详细介绍


    目录

    函数原型

    参数介绍

    mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)

    args (tuple or torch.Tensor)

    f

    export_params (bool, default True)

    verbose (bool, default False)

    training (enum, default TrainingMode.EVAL)

    input_names (list of str, default empty list)

    output_names (list of str, default empty list)

    operator_export_type (enum, default None)

    opset_version (int, default 9)

    do_constant_folding (bool, default False)

    example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)

    dynamic_axes (dict> or dict, default empty dict),>,>

    keep_initializers_as_inputs (bool, default None)

    custom_opsets (dict, default empty dict),>


    函数原型

    参数介绍

    • mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)

    需要转换的模型,支持的模型类型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction

    • args (tuple or torch.Tensor)

    args可以被设置成三种形式

    1.一个tuple

    args = (x, y, z)
    
    • 1

    这个tuple应该与模型的输入相对应,任何非Tensor的输入都会被硬编码入onnx模型,所有Tensor类型的参数会被当做onnx模型的输入。

    2.一个Tensor

    args = torch.Tensor([1, 2, 3])
    
    • 1

    一般这种情况下模型只有一个输入

    3.一个带有字典的tuple

    args = (x,
            {'y': input_y,
             'z': input_z})
    
    • 1
    • 2
    • 3

    这种情况下,所有字典之前的参数会被当做“非关键字”参数传入网络,字典种的键值对会被当做关键字参数传入网络。如果网络中的关键字参数未出现在此字典中,将会使用默认值,如果没有设定默认值,则会被指定为None。

    NOTE:

    一个特殊情况,当网络本身最后一个参数为字典时,直接在tuple最后写一个字典则会被误认为关键字传参。所以,可以通过在tuple最后添加一个空字典来解决。

    #错误写法:
    
    torch.onnx.export(
        model,
        (x,
         # WRONG: will be interpreted as named arguments
         {y: z}),
        "test.onnx.pb")
    
    # 纠正
    
    torch.onnx.export(
        model,
        (x,
         {y: z},
         {}),
        "test.onnx.pb")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • f

    一个文件类对象或一个路径字符串,二进制的protocol buffer将被写入此文件

    • export_params (bool, default True)

    如果为True则导出模型的参数。如果想导出一个未训练的模型,则设为False

    • verbose (bool, default False)

    如果为True,则打印一些转换日志,并且onnx模型中会包含doc_string信息。

    • training (enum, default TrainingMode.EVAL)

    枚举类型包括:

    TrainingMode.EVAL - 以推理模式导出模型。

    TrainingMode.PRESERVE - 如果model.training为False,则以推理模式导出;否则以训练模式导出。

    TrainingMode.TRAINING - 以训练模式导出,此模式将禁止一些影响训练的优化操作。

    • input_names (list of str, default empty list)

    按顺序分配给onnx图的输入节点的名称列表。

    • output_names (list of str, default empty list)

    按顺序分配给onnx图的输出节点的名称列表。

    • operator_export_type (enum, default None)

    默认为OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,则默认为OperatorExportTypes.ONNX_ATEN_FALLBACK。

    枚举类型包括:

    OperatorExportTypes.ONNX - 将所有操作导出为ONNX操作。

    OperatorExportTypes.ONNX_FALLTHROUGH - 试图将所有操作导出为ONNX操作,但碰到无法转换的操作(如onnx未实现的操作),则将操作导出为“自定义操作”,为了使导出的模型可用,运行时必须支持这些自定义操作。支持自定义操作方法见链接

    OperatorExportTypes.ONNX_ATEN - 所有ATen操作导出为ATen操作,ATen是Pytorch的内建tensor库,所以这将使得模型直接使用Pytorch实现。(此方法转换的模型只能被Caffe2直接使用)

    OperatorExportTypes.ONNX_ATEN_FALLBACK - 试图将所有的ATen操作也转换为ONNX操作,如果无法转换则转换为ATen操作(此方法转换的模型只能被Caffe2直接使用)。例如:

    # 转换前:
    graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      # conversion unsupported
      %4 : Float = aten::triu(%0, %3)
      # conversion supported
      %5 : Float = aten::mul(%4, %0)
      return (%5)
    
    
    # 转换后:
    graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      # not converted
      %2 : Float = aten::ATen[operator="triu"](%0, %1)
      # converted
      %3 : Float = onnx::Mul(%2, %0)
      return (%3)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • opset_version (int, default 9)

    默认是9。值必须等于_onnx_main_opset或在_onnx_stable_opsets之内。具体可在torch/onnx/symbolic_helper.py中找到。例如:

    _default_onnx_opset_version = 9
    
    _onnx_main_opset = 13
    
    _onnx_stable_opsets = [7, 8, 9, 10, 11, 12]
    
    _export_onnx_opset_version = _default_onnx_opset_version
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • do_constant_folding (bool, default False)

    是否使用“常量折叠”优化。常量折叠将使用一些算好的常量来优化一些输入全为常量的节点。

    • example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)

    当需输入模型为ScriptModule 或 ScriptFunction时必须提供。此参数用于确定输出的类型和形状,而不跟踪(tracing )模型的执行。

    • dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)

    通过以下规则设置动态的维度:

    KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。

    VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。

    具体可参考如下示例:

    class SumModule(torch.nn.Module):
        def forward(self, x):
            return torch.sum(x, dim=1)
    
    
    
    # 以动态尺寸模式导出模型
    
    torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
                      input_names=["x"], output_names=["sum"],
                      dynamic_axes={
                          # dict value: manually named axes
                          "x": {0: "my_custom_axis_name"},
                          # list value: automatic names
                          "sum": [0],
                      })
    
    ### 导出后的节点信息
    
    ##input
    
    input {
      name: "x"
      ...
          shape {
            dim {
              dim_param: "my_custom_axis_name"  # axis 0
            }
            dim {
              dim_value: 2  # axis 1
    ...
    
    
    ##output
    output {
      name: "sum"
      ...
          shape {
            dim {
              dim_param: "sum_dynamic_axes_1"  # axis 0
    ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • keep_initializers_as_inputs (bool, default None)

    NONE

    • custom_opsets (dict<str, int>, default empty dict)

    NONE

  • 相关阅读:
    什么是RBAC?
    前端mounted的使用
    vue3脚手架搭建
    macos苹果电脑清理软件有哪些?cleanmymac和腾讯柠檬哪个好
    SpringBoot启动流程分析之创建SpringApplication对象(一)
    net基于asp.net的二手商品的交易系统-二手网站-计算机毕业设计
    FS4061A(5V USB输入、双节锂电池串联应用、5v升压充电8.4v管理IC
    做SEO为什么有的网站收录很难做?
    架构之路15. 创业 - 厌倦
    VUE 配置环境变量
  • 原文地址:https://blog.csdn.net/jiong9412/article/details/125383053