• 修改ONNX模型节点


    网上几乎所有的相关帖子都没注意到一个问题,就是要满足拓扑排序,不满足拓扑排序后面的check_model是会报错的。(其实就是要满足节点的先后顺序)

    以修改ViT中的Concat节点为例:

    1. import onnx
    2. onnx_model = onnx.load("./vit.onnx")
    3. graph = onnx_model.graph
    4. node = graph.node
    5. orig_len = len(node)
    6. node_name = '/Concat'
    7. node_index = -1
    8. for i in range(len(node)):
    9. if node_name == node[i].name:
    10. node_index = i
    11. inputs_list = []
    12. outputs_list = []
    13. for item in node[node_index].input:
    14. inputs_list.append(item)
    15. for item in node[node_index].output:
    16. outputs_list.append(item)
    17. offset = len(inputs_list)+1
    18. for i in range(len(graph.node)):
    19. T_index = len(node)-i
    20. if T_index == len(node):
    21. T_index = T_index-1
    22. if T_index>node_index and T_index<orig_len:
    23. T_node = node[T_index]
    24. if (T_index+offset) < len(node):
    25. graph.node.remove(node[T_index+offset])
    26. graph.node.insert(T_index+offset, T_node)
    27. for i in range(len(inputs_list)):
    28. index = node_index+i
    29. attr = onnx.helper.make_attribute('perm', [1, 0, 2])
    30. new_scale_node = onnx.helper.make_node(
    31. "Transpose",
    32. inputs=[inputs_list[i]],
    33. outputs=[inputs_list[i]+'_add']
    34. )
    35. old_node = node[index]
    36. graph.node.remove(old_node)
    37. graph.node.insert(index, new_scale_node)
    38. node[index].attribute.insert(0, attr)
    39. node[index].name = 'Transpose_'+str(index)
    40. index = node_index+len(inputs_list)
    41. attr = onnx.helper.make_attribute('axis', 0) #添加属性
    42. new_scale_node = onnx.helper.make_node(
    43. 'Concat',
    44. inputs=[item+'_add' for item in inputs_list],
    45. outputs=['output_changed']
    46. ) # 新建新节点
    47. old_scale_node = node[index]
    48. graph.node.remove(old_scale_node) # 删除旧节点
    49. graph.node.insert(index, new_scale_node) # 插入新节点
    50. node[index].attribute.insert(0, attr)
    51. node[index].name = node_name
    52. index = node_index+len(inputs_list)+1
    53. attr = onnx.helper.make_attribute('perm', [1, 0, 2])
    54. new_scale_node = onnx.helper.make_node(
    55. "Transpose",
    56. inputs=['output_changed'],
    57. outputs=[item for item in outputs_list]
    58. )
    59. old_node = node[index]
    60. graph.node.remove(old_node)
    61. graph.node.insert(index, new_scale_node)
    62. node[index].attribute.insert(0, attr)
    63. node[index].name = 'Transpose_'+str(index)
    64. graph = onnx.helper.make_graph(graph.node, graph.name, graph.input, graph.output, graph.initializer)
    65. info_model = onnx.helper.make_model(graph)
    66. # onnx_model = onnx.shape_inference.infer_shapes(info_model)
    67. onnx.checker.check_model(onnx_model)
    68. onnx.save(onnx_model, './changed_vit.onnx')

  • 相关阅读:
    「项目管理」如何做好项目进度管理计划?
    项目配置vue.config jsconfig babel.config .prettierc .env .eslintrc
    LeetCode1137第N个泰波那契数
    Python 搭建 FastAPI 项目
    JQuery实现图片切换(自动切换+手动切换)
    b站黑马JavaScript的Ajax案例代码——图书管理案例
    java基于微信小程序的大学生个人家庭理财产品 uniapp小程序
    向上管理读书笔记
    vue2_路由02_嵌套/多级路由
    Linux命令(83)之cut
  • 原文地址:https://blog.csdn.net/bule_sky_wait_me/article/details/132872998