网上几乎所有相关的帖子都忽略了一个问题：修改后的 `graph.node` 列表必须满足拓扑排序，即每个节点都要排在产生其输入的节点之后，否则后面调用 `onnx.checker.check_model` 时会报错。（其实就是要保证节点的先后顺序正确。）
以修改 ViT 中的 Concat 节点为例（将其替换为：各输入先做 Transpose → 沿 axis 0 做 Concat → 结果再 Transpose 回来）：
"""Replace the ViT Concat node ``/Concat`` with Transpose -> Concat -> Transpose.

Each Concat input is transposed with PERM, the results are concatenated along
axis 0, and that result is transposed back with PERM into the original output
tensor name, so downstream consumers are untouched.

The replacement chain is written into ``graph.node`` at exactly the position
of the old Concat. Its inputs are produced by earlier nodes and its consumers
come later, so the node list stays topologically sorted and
``onnx.checker.check_model`` accepts the result.
"""
import onnx

SRC_PATH = "./vit.onnx"
DST_PATH = "./changed_vit.onnx"
NODE_NAME = "/Concat"
# Swaps dims 0 and 1 (and is its own inverse), so "transpose, concat on
# axis 0, transpose back" equals a Concat along axis 1. This perm is rank-3
# only -- assumes the Concat inputs are 3-D tensors; TODO confirm per model.
PERM = [1, 0, 2]
# Concat axes for which the rewrite above is equivalent (-2 == 1 for rank 3).
_SUPPORTED_AXES = (1, -2)


def _find_node_index(graph, name):
    """Return the index of the first node in ``graph.node`` named ``name``.

    Raises:
        ValueError: if no node has that name (previously the script silently
            fell back to index -1, i.e. rewrote the last node).
    """
    for i, candidate in enumerate(graph.node):
        if candidate.name == name:
            return i
    raise ValueError(f"node {name!r} not found in graph")


def _build_replacement(concat):
    """Return the new nodes (in topological order) that replace ``concat``.

    Raises:
        ValueError: if ``concat``'s axis is not one the PERM rewrite supports.
    """
    axis = next(
        (onnx.helper.get_attribute_value(a) for a in concat.attribute if a.name == "axis"),
        None,
    )
    if axis not in _SUPPORTED_AXES:
        raise ValueError(
            f"{concat.name!r} concatenates along axis {axis!r}; the "
            f"transpose rewrite with perm {PERM} only matches axis 1"
        )

    name = concat.name
    # Copy the names out: the original node is deleted from the graph later.
    inputs = list(concat.input)
    outputs = list(concat.output)
    transposed = [f"{inp}_add" for inp in inputs]
    # Derived from the real output name so it cannot clash with other tensors.
    concat_out = f"{outputs[0]}_changed"

    new_nodes = [
        onnx.helper.make_node(
            "Transpose", inputs=[inp], outputs=[t_out],
            name=f"{name}_transpose_in{i}", perm=PERM,
        )
        for i, (inp, t_out) in enumerate(zip(inputs, transposed))
    ]
    new_nodes.append(
        onnx.helper.make_node(
            "Concat", inputs=transposed, outputs=[concat_out], name=name, axis=0,
        )
    )
    new_nodes.append(
        onnx.helper.make_node(
            "Transpose", inputs=[concat_out], outputs=outputs,
            name=f"{name}_transpose_out", perm=PERM,
        )
    )
    return new_nodes


def main():
    """Load SRC_PATH, rewrite NODE_NAME, validate, and save to DST_PATH."""
    onnx_model = onnx.load(SRC_PATH)
    graph = onnx_model.graph

    node_index = _find_node_index(graph, NODE_NAME)
    replacement = _build_replacement(graph.node[node_index])

    # Splice by index (never by value: RepeatedField.remove() deletes the
    # first *equal* message, which is what made the old shifting loop fragile).
    del graph.node[node_index]
    for offset, new_node in enumerate(replacement):
        graph.node.insert(node_index + offset, new_node)

    # The graph was edited in place, so onnx_model already holds the change
    # (and keeps its opset_import / metadata, unlike a fresh make_model()).
    onnx.checker.check_model(onnx_model)
    onnx.save(onnx_model, DST_PATH)


if __name__ == "__main__":
    main()