• ONNX-GraphSurgeon, the ONNX scalpel: adding, removing, and modifying operations at a model's input (Part 1)


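    I. Introduction

    ONNX (Open Neural Network Exchange) is an open format for representing deep learning models, designed to make models interoperable across frameworks. In practice, though, an exported model often has to be customized or optimized to fit a particular deployment scenario. ONNX-GraphSurgeon exists for exactly this purpose: it lets developers modify and optimize ONNX models with little effort.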

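    II. About ONNX-GraphSurgeon

    ONNX-GraphSurgeon is a Python library for manipulating ONNX computation graphs. Its API covers creating, deleting, modifying, and querying the nodes and tensors of a graph. Its main features are:

    1. Flexibility: the graph structure can be changed freely, for example by adding, removing, or replacing nodes and the tensors that connect them.
    2. Efficiency: graph-level optimizations such as layer fusion and model pruning can be carried out by rewriting the graph.
    3. Ease of use: the API is compact, so developers can get started quickly.

    Official source code:
    https://github.com/NVIDIA/TensorRT/tree/release/10.1/tools/onnx-graphsurgeon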

    III. Installing ONNX-GraphSurgeon

    Before using ONNX-GraphSurgeon, make sure the following dependencies are installed:

    Python 3.6 or later

    ONNX 1.6.0 or later

    numpy

    Install it with:

    pip install onnx-graphsurgeon
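
    After installing, it can be useful to confirm that the packages are importable. The following is a small sketch using only the standard library; the helper name missing_packages is made up for this example. Note that the pip package onnx-graphsurgeon is imported as onnx_graphsurgeon.

```python
import importlib.util

def missing_packages(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Import names (not pip names) of the dependencies listed above.
required = ["onnx", "numpy", "onnx_graphsurgeon"]
missing = missing_packages(required)
print("missing packages:", missing if missing else "none")
```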

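    IV. Processing the input side of an ONNX model

    1. Why does an ONNX model need trimming?

    [Figure: the ONNX graph you expect the export to produce]

    [Figure: the ONNX graph the export actually produces]

    [Figure: the ONNX graph after trimming with ONNX-GraphSurgeon]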

    2. Generating the model

    Build a model that computes output = ReLU((A * X^T) + B) (.) C + D, where (.) denotes element-wise multiplication.

    #!/usr/bin/env python3
    # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #
    import onnx_graphsurgeon as gs
    import numpy as np
    import onnx

    print("Graph.layer Help:\n{}".format(gs.Graph.layer.__doc__))


    # We can use `Graph.register()` to add a function to the Graph class. Later, we can invoke the function
    # directly on instances of the graph, e.g., `graph.add(...)`
    @gs.Graph.register()
    def add(self, a, b):
        # The Graph.layer function creates a node, adds inputs and outputs to it, and finally adds it to the graph.
        # It returns the output tensors of the node to make it easy to chain.
        # The function will append an index to any strings provided for inputs/outputs prior
        # to using them to construct tensors. This will ensure that multiple calls to the layer() function
        # will generate distinct tensors. However, this does NOT guarantee that there will be no overlap with
        # other tensors in the graph. Hence, you should choose the prefixes to minimize the possibility of
        # collisions.
        return self.layer(op="Add", inputs=[a, b], outputs=["add_out_gs"])


    @gs.Graph.register()
    def mul(self, a, b):
        return self.layer(op="Mul", inputs=[a, b], outputs=["mul_out_gs"])


    @gs.Graph.register()
    def gemm(self, a, b, trans_a=False, trans_b=False):
        attrs = {"transA": int(trans_a), "transB": int(trans_b)}
        return self.layer(op="Gemm", inputs=[a, b], outputs=["gemm_out_gs"], attrs=attrs)


    # You can also specify a set of opsets when registering a function.
    # By default, the function is registered for all opsets lower than Graph.DEFAULT_OPSET
    @gs.Graph.register(opsets=[11])
    def relu(self, a):
        return self.layer(op="Relu", inputs=[a], outputs=["act_out_gs"])


    # Note that the same function can be defined in different ways for different opsets.
    # It will only be called if the Graph's opset matches one of the opsets for which the function is registered.
    # Hence, for the opset 11 graph used in this example, the following function will never be used.
    @gs.Graph.register(opsets=[1])
    def relu(self, a):
        raise NotImplementedError("This function has not been implemented!")


    ##########################################################################################################
    # The functions registered above greatly simplify the process of building the graph itself.
    graph = gs.Graph(opset=11)

    # Generates a graph which computes:
    # output = ReLU((A * X^T) + B) (.) C + D
    X = gs.Variable(name="X", shape=(64, 64), dtype=np.float32)
    graph.inputs = [X]

    # axt = (A * X^T)
    # Note that we can use NumPy arrays directly (e.g. Tensor A),
    # instead of Constants. These will automatically be converted to Constants.
    A = np.ones(shape=(64, 64), dtype=np.float32)
    axt = graph.gemm(A, X, trans_b=True)

    # dense = ReLU(axt + B)
    B = np.ones((64, 64), dtype=np.float32) * 0.5
    dense = graph.relu(*graph.add(*axt, B))

    # output = dense (.) C + D
    # If a Tensor instance is provided (e.g. Tensor C), it will not be modified at all.
    # If you prefer to set the exact names of tensors in the graph, you should
    # construct tensors manually instead of passing strings or NumPy arrays.
    C = gs.Constant(name="C", values=np.ones(shape=(64, 64), dtype=np.float32))
    D = np.ones(shape=(64, 64), dtype=np.float32)
    graph.outputs = graph.add(*graph.mul(*dense, C), D)

    # Finally, we need to set the output datatype to make this a valid ONNX model.
    # In our case, all the data types are float32.
    for out in graph.outputs:
        out.dtype = np.float32

    onnx.save(gs.export_onnx(graph), "model.onnx")
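
    As a sanity check on what this graph computes, the same formula can be written directly in NumPy. This is only a reference sketch of the arithmetic, using the constant values from the script above; the function name reference_output is made up here, and the sketch does not run the ONNX model itself.

```python
import numpy as np

def reference_output(X):
    """NumPy version of output = ReLU((A * X^T) + B) (.) C + D with the constants used above."""
    A = np.ones((64, 64), dtype=np.float32)
    B = np.ones((64, 64), dtype=np.float32) * 0.5
    C = np.ones((64, 64), dtype=np.float32)
    D = np.ones((64, 64), dtype=np.float32)
    axt = A @ X.T                      # Gemm with transB=1
    dense = np.maximum(axt + B, 0.0)   # Add followed by Relu
    return dense * C + D               # element-wise Mul, then Add

X = np.zeros((64, 64), dtype=np.float32)
out = reference_output(X)
print(out.shape, out.dtype, float(out[0, 0]))
```

    Comparing these values against the output of an inference runtime on model.onnx is one way to confirm that the exported graph matches the intended formula.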

     
    

    3. Adding operations in front of the first node

    Insert mean subtraction and division by the standard deviation in front of the Gemm operation.

    import onnx_graphsurgeon as gs
    import numpy as np
    import onnx

    # Original model: output = ReLU((A * X^T) + B) (.) C + D
    # Goal: normalize X (subtract the mean, divide by the standard deviation) before Gemm.
    graph = gs.import_onnx(onnx.load("model.onnx"))
    tensors = graph.tensors()

    # Reuse the graph's existing input tensor rather than creating a second tensor named "X".
    X = tensors["X"]

    # Mean and standard deviation (placeholder values; replace them with your own).
    mean_value = np.array([0.5], dtype=np.float32)
    std_value = np.array([0.2], dtype=np.float32)
    mean = gs.Constant(name="mean", values=mean_value)
    std = gs.Constant(name="std", values=std_value)

    # Create the Sub (subtract mean) and Div (divide by standard deviation) nodes.
    sub_output = gs.Variable(name="X_minus_mean", shape=(64, 64), dtype=np.float32)
    div_output = gs.Variable(name="X_normalized", shape=(64, 64), dtype=np.float32)
    sub_node = gs.Node(op="Sub", inputs=[X, mean], outputs=[sub_output])
    div_node = gs.Node(op="Div", inputs=[sub_output, std], outputs=[div_output])

    # Add the new nodes to the graph.
    graph.nodes.extend([sub_node, div_node])

    # Feed the normalized tensor into Gemm in place of X (Gemm's second input).
    gemm_node = [node for node in graph.nodes if node.op == "Gemm"][0]
    gemm_node.inputs[1] = div_output

    # Clean up and topologically sort the graph.
    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), "model_add.onnx")

    4. Modifying a node's input

    Change the input from X to Y.

    import onnx_graphsurgeon as gs
    import numpy as np
    import onnx

    # output = ReLU((A * X^T) + B) (.) C + D
    graph = gs.import_onnx(onnx.load("model.onnx"))

    # Replace the graph input X with a new tensor Y.
    Y = gs.Variable(name="Y", shape=(64, 64), dtype=np.float32)
    graph.inputs = [Y]

    # Gemm's second input was X; point it at Y instead.
    gemm_node = [node for node in graph.nodes if node.op == "Gemm"][0]
    gemm_node.inputs[1] = Y

    # Clean up and topologically sort the graph.
    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), "model_modify.onnx")

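    5. Deleting a node

    Remove the Gemm operation by feeding X directly into the first Add node. The Gemm node then no longer contributes to the graph output, so cleanup() removes it.

    import onnx_graphsurgeon as gs
    import numpy as np
    import onnx

    # output = ReLU((A * X^T) + B) (.) C + D
    graph = gs.import_onnx(onnx.load("model.onnx"))
    tensors = graph.tensors()

    # Bypass Gemm: the first Add node takes X instead of the Gemm output.
    add_node = [node for node in graph.nodes if node.op == "Add"][0]
    add_node.inputs[0] = tensors["X"]

    # Clean up (this drops the now-unused Gemm node) and topologically sort the graph.
    graph.cleanup().toposort()
    onnx.save(gs.export_onnx(graph), "model_delete.onnx")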
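
    Why the Gemm node disappears can be illustrated with a small, library-independent sketch of dead-node elimination: walk backwards from the graph outputs and keep only the nodes that are reached. This is a conceptual model, not ONNX-GraphSurgeon's actual implementation, and the node and tensor names below are illustrative.

```python
def live_nodes(nodes, graph_outputs):
    """nodes maps a node name to (input tensor names, output tensor names).

    Returns the set of node names that contribute to any of graph_outputs.
    """
    producer = {t: name for name, (_, outs) in nodes.items() for t in outs}
    live, stack = set(), list(graph_outputs)
    while stack:
        tensor = stack.pop()
        node = producer.get(tensor)  # None for graph inputs and constants
        if node is not None and node not in live:
            live.add(node)
            stack.extend(nodes[node][0])
    return live

# The example graph after rewiring the first Add from the Gemm output to X.
nodes = {
    "Gemm": (["A", "X"], ["gemm_out"]),
    "Add_0": (["X", "B"], ["add_out_0"]),
    "Relu": (["add_out_0"], ["act_out"]),
    "Mul": (["act_out", "C"], ["mul_out"]),
    "Add_1": (["mul_out", "D"], ["add_out_1"]),
}
kept = live_nodes(nodes, ["add_out_1"])
print(sorted(kept))
```

    Nothing reachable from the output consumes gemm_out any more, so Gemm is not in the kept set.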

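    Summary:

    ONNX-GraphSurgeon is a powerful tool for editing and optimizing deep learning models in ONNX form. This article used it to build a model and then to add, replace, and remove operations at the model's input. Used sensibly, this kind of graph editing can help improve inference speed and resource utilization, so that models perform better across different hardware platforms.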


  • Original article: https://blog.csdn.net/laukal/article/details/140408682