• 使用protobuf解析Onnx文件


    使用OpenCV加载Onnx推理的时候,无法获取到Onnx的网络输入大小,并且对推理速度要求不高不需要使用TensorRT的时候,如何才能得知Onnx的一些必要的信息,OpenCV没有提供接口,只能自己从Onnx文件中解析了。

    Onnx文件是使用protobuf序列化后的二进制数据,想要读取里面的信息需要使用protobuf将其反序列化为对象才行。

    第一步:

    编译 protobuf,protocolbuffers/protobuf: Protocol Buffers - Google's data interchange format (github.com)

    使用cmake生成vs工程直接编译即可。

    第二步:

    使用protoc命令,将编写好的proto文件生成C++类定义文件。

    .\protoc.exe --cpp_out=. ./onnx.proto

    得到 onnx.pb.h 和 onnx.pb.cpp 两个文件,将其加入到工程中。

    Onnx.proto 文件可以从 onnx/onnx: Open standard for machine learning interoperability (github.com) 找到。

    使用编译好的库文件编译的过程中,如果出现 无法解析的外部符号 "class google::protobuf::internal::ExplicitlyConstructed fixed_address_empty_string" 这个错误,添加预处理宏定义 PROTOBUF_USE_DLLS 即可。

    关键代码:

    加载反序列化onnx文件

    1.     ifstream fin("model.onnx", std::ios::in | std::ios::binary);
    2.     onnx::ModelProto onnx_model;
    3.     onnx_model.ParseFromIstream(&fin);

    打印一些信息   

    1. std::cout << "ir_version: " << onnx_model.ir_version() << std::endl;
    2.     std::cout << "opset_import_size: " << onnx_model.opset_import_size() << std::endl;
    3.     std::cout << "OperatorSetIdProto domain: " << onnx_model.opset_import(0).domain() << std::endl;
    4.     std::cout << "OperatorSetIdProto version: " << onnx_model.opset_import(0).version() << std::endl;
    5.     std::cout << "producer_name: " << onnx_model.producer_name() << std::endl;
    6.     std::cout << "producer_version: " << onnx_model.producer_version() << std::endl;
    7.     std::cout << "domain: " << onnx_model.domain() << std::endl;
    8.     std::cout << "model_version: " << onnx_model.model_version() << std::endl;
    9.     std::cout << "doc_string: " << onnx_model.doc_string() << std::endl;

    输入节点的个数

     

    onnx_model.graph().input_size()

    输入节点的名称

     

     onnx_model.graph().input(0).name()

    输入输出节点的数据类型

    onnx_model.graph().input(0).type().tensor_type().elem_type()

    返回的是int类型,与实际数据类型对应关系见:onnx/onnx.proto at main · onnx/onnx (github.com) 文件中的 DataType 枚举类型。

    输入节点的输入维度

    1. int dim_size = onnx_model.graph().input(0).type().tensor_type().shape().dim_size();
    2. for (int i = 0; i < dim_size; i++)
    3. {
    4.     onnx_model.graph().input(0).type().tensor_type().shape().dim().Get(i).dim_value();
    5. }

    封装类 OnnxInfo

    OnnxInfo.h

    1. #ifdef __cplusplus
    2. extern "C" {
    3. #endif
    4.     DLL_API uint64_t onnx_load(const char* path);
    5.     DLL_API void onnx_close(uint64_t ptr_addr);
    6.     DLL_API int onnx_get_input_count(uint64_t ptr_addr);
    7.     DLL_API int onnx_get_output_count(uint64_t ptr_addr);
    8.     DLL_API const char* onnx_get_input_name(uint64_t ptr_addr, int input_index);
    9.     DLL_API const char* onnx_get_output_name(uint64_t ptr_addr, int output_index);
    10.     DLL_API const char* onnx_get_input_data_type(uint64_t ptr_addr, int input_index);
    11.     DLL_API const char* onnx_get_output_data_type(uint64_t ptr_addr, int output_index);
    12.     DLL_API int onnx_get_input_dims(uint64_t ptr_addr, int input_index, int* dims);
    13.     DLL_API int onnx_get_output_dims(uint64_t ptr_addr, int output_index, int* dims);
    14. #ifdef __cplusplus
    15. }
    16. #endif

    OnnxInfo.cpp

    1. #define _DLL_INTERNEL_
    2. #include "OnnxInfo.h"
    3. #include "onnx.pb.h"
    4. #include
    5. #include
    6. #include
    7. using namespace std;
    8. #ifdef __cplusplus
    9. extern "C" {
    10. #endif
    11. namespace {
    12. char g_string[256];
    13. vector> g_models;
    14. void g_release_ptr(uint64_t ptr_addr)
    15. {
    16. auto iter = g_models.begin();
    17. while (iter != g_models.end())
    18. {
    19. if (ptr_addr == (uint64_t)iter->get())
    20. {
    21. g_models.erase(iter);
    22. break;
    23. }
    24. iter++;
    25. }
    26. }
    27. shared_ptr g_get_ptr(uint64_t ptr_addr)
    28. {
    29. for (auto& ptr : g_models)
    30. {
    31. if (ptr_addr == (uint64_t)ptr.get())
    32. {
    33. return ptr;
    34. }
    35. }
    36. return nullptr;
    37. }
    38. const char* g_get_data_type_name_by_id(int data_type_id)
    39. {
    40. auto dt = (onnx::TensorProto_DataType)data_type_id;
    41. switch (dt)
    42. {
    43. case onnx::TensorProto_DataType_UNDEFINED:
    44. return "UNDEFINED";
    45. case onnx::TensorProto_DataType_FLOAT:
    46. return "FLOAT";
    47. case onnx::TensorProto_DataType_UINT8:
    48. return "UINT8";
    49. case onnx::TensorProto_DataType_INT8:
    50. return "FLOAT";
    51. case onnx::TensorProto_DataType_UINT16:
    52. return "UINT16";
    53. case onnx::TensorProto_DataType_INT16:
    54. return "INT16";
    55. case onnx::TensorProto_DataType_INT32:
    56. return "INT32";
    57. case onnx::TensorProto_DataType_INT64:
    58. return "INT64";
    59. case onnx::TensorProto_DataType_STRING:
    60. return "STRING";
    61. case onnx::TensorProto_DataType_BOOL:
    62. return "BOOL";
    63. case onnx::TensorProto_DataType_FLOAT16:
    64. return "FLOAT16";
    65. case onnx::TensorProto_DataType_DOUBLE:
    66. return "DOUBLE";
    67. case onnx::TensorProto_DataType_UINT32:
    68. return "UINT32";
    69. case onnx::TensorProto_DataType_UINT64:
    70. return "UINT64";
    71. case onnx::TensorProto_DataType_COMPLEX64:
    72. return "COMPLEX64";
    73. case onnx::TensorProto_DataType_COMPLEX128:
    74. return "COMPLEX128";
    75. case onnx::TensorProto_DataType_BFLOAT16:
    76. return "BFLOAT16";
    77. default:
    78. return "";
    79. }
    80. }
    81. }
    82. /// 加载onnx
    83. uint64_t onnx_load(const char* path)
    84. {
    85. ifstream fin(path, std::ios::in | std::ios::binary);
    86. auto onnx_model_ptr = make_shared();
    87. bool bret = onnx_model_ptr->ParseFromIstream(&fin);
    88. fin.close();
    89. if (!bret) return 0;
    90. uint64_t ptr = (uint64_t)onnx_model_ptr.get();
    91. g_models.push_back(move(onnx_model_ptr));
    92. return ptr;
    93. }
    94. /// 关闭onnx
    95. void onnx_close(uint64_t ptr_addr)
    96. {
    97. g_release_ptr(ptr_addr);
    98. }
    99. /// 获取输入节点的个数
    100. int onnx_get_input_count(uint64_t ptr_addr)
    101. {
    102. auto ptr = g_get_ptr(ptr_addr);
    103. if (!ptr.get()) return -1;
    104. if (!ptr->has_graph()) return -2;
    105. return ptr->graph().input_size();
    106. }
    107. /// 获取输出节点的个数
    108. int onnx_get_output_count(uint64_t ptr_addr)
    109. {
    110. auto ptr = g_get_ptr(ptr_addr);
    111. if (!ptr.get()) return -1;
    112. if (!ptr->has_graph()) return -2;
    113. return ptr->graph().output_size();
    114. }
    115. /// 获取输入节点的名称
    116. const char* onnx_get_input_name(uint64_t ptr_addr, int input_index)
    117. {
    118. auto ptr = g_get_ptr(ptr_addr);
    119. if (!ptr.get()) return "";
    120. if (!ptr->has_graph()) return "";
    121. int input_size = ptr->graph().input_size();
    122. if (input_index >= input_size || input_index < 0) return "";
    123. auto input = ptr->graph().input(input_index);
    124. sprintf_s(g_string, sizeof(g_string), input.name().c_str());
    125. return g_string;
    126. }
    127. /// 获取输出节点的名称
    128. const char* onnx_get_output_name(uint64_t ptr_addr, int output_index)
    129. {
    130. auto ptr = g_get_ptr(ptr_addr);
    131. if (!ptr.get()) return "";
    132. if (!ptr->has_graph()) return "";
    133. int output_size = ptr->graph().output_size();
    134. if (output_index >= output_size || output_index < 0) return "";
    135. auto output = ptr->graph().output(output_index);
    136. sprintf_s(g_string, sizeof(g_string), output.name().c_str());
    137. return g_string;
    138. }
    139. /// 获取输入节点的数据类型
    140. const char* onnx_get_input_data_type(uint64_t ptr_addr, int input_index)
    141. {
    142. auto ptr = g_get_ptr(ptr_addr);
    143. if (!ptr.get()) return "";
    144. if (!ptr->has_graph()) return "";
    145. int input_size = ptr->graph().input_size();
    146. if (input_index >= input_size || input_index < 0) return "";
    147. auto input = ptr->graph().input(input_index);
    148. auto type_id = input.type().tensor_type().elem_type();
    149. return g_get_data_type_name_by_id(type_id);
    150. }
    151. /// 获取输出节点的数据类型
    152. const char* onnx_get_output_data_type(uint64_t ptr_addr, int output_index)
    153. {
    154. auto ptr = g_get_ptr(ptr_addr);
    155. if (!ptr.get()) return "";
    156. if (!ptr->has_graph()) return "";
    157. int output_size = ptr->graph().output_size();
    158. if (output_index >= output_size || output_index < 0) return "";
    159. auto output = ptr->graph().output(output_index);
    160. auto type_id = output.type().tensor_type().elem_type();
    161. return g_get_data_type_name_by_id(type_id);
    162. }
    163. /// 获取输入节点的维数
    164. int onnx_get_input_dims(uint64_t ptr_addr, int input_index, int* dims)
    165. {
    166. auto ptr = g_get_ptr(ptr_addr);
    167. if (!ptr.get()) return -1;
    168. if (!ptr->has_graph()) return -2;
    169. int input_size = ptr->graph().input_size();
    170. if (input_index >= input_size || input_index < 0) return -3;
    171. auto input = ptr->graph().input(input_index);
    172. int dim_size = input.type().tensor_type().shape().dim_size();
    173. if (dims) for (int i = 0; i < dim_size; i++)
    174. {
    175. dims[i] = input.type().tensor_type().shape().dim().Get(i).dim_value();
    176. }
    177. return dim_size;
    178. }
    179. /// 获取输出节点的维数
    180. int onnx_get_output_dims(uint64_t ptr_addr, int output_index, int* dims)
    181. {
    182. auto ptr = g_get_ptr(ptr_addr);
    183. if (!ptr.get()) return -1;
    184. if (!ptr->has_graph()) return -2;
    185. int output_size = ptr->graph().output_size();
    186. if (output_index >= output_size || output_index < 0) return -3;
    187. auto output = ptr->graph().output(output_index);
    188. int dim_size = output.type().tensor_type().shape().dim_size();
    189. if (dims) for (int i = 0; i < dim_size; i++)
    190. {
    191. dims[i] = output.type().tensor_type().shape().dim().Get(i).dim_value();
    192. }
    193. return dim_size;
    194. }
    195. #ifdef __cplusplus
    196. }
    197. #endif

    调用方法

    1. int main()
    2. {
    3. const char* path = "myunet.onnx";
    4. auto adr = onnx_load(path);
    5. int ic = onnx_get_input_count(adr);
    6. int oc = onnx_get_output_count(adr);
    7. cout << onnx_get_input_name(adr, 0) << endl;
    8. cout << onnx_get_output_name(adr, 0) << endl;
    9. cout << onnx_get_input_data_type(adr, 0) << endl;
    10. cout << onnx_get_output_data_type(adr, 0) << endl;
    11. int dim_size;
    12. int dims[8];
    13. cout << "input dim size: " << onnx_get_input_dims(adr, 0, 0) << endl;
    14. cout << "input dim size2: " << (dim_size = onnx_get_input_dims(adr, 0, dims)) << endl;
    15. for (int i = 0; i < dim_size; i++)
    16. {
    17. cout << dims[i] << " ";
    18. }
    19. cout << endl;
    20. cout << "output dim size: " << onnx_get_output_dims(adr, 0, 0) << endl;
    21. cout << "output dim size2: " << (dim_size = onnx_get_output_dims(adr, 0, dims)) << endl;
    22. for (int i = 0; i < dim_size; i++)
    23. {
    24. cout << dims[i] << " ";
    25. }
    26. cout << endl;
    27. onnx_close(adr);
    28. cin.ignore();
    29. return 0;
    30. }

  • 相关阅读:
    Kafka干货之「零拷贝」
    基于JSP的九宫格日志网站
    Python与数据分析--Pandas-1
    计算机系统的层次结构
    tcpdump抓包详解
    【API篇】十一、Flink水位线传递与迟到数据处理
    STM32驱动AHT10&OLED显示温湿度
    FFmpeg开发笔记(九)Linux交叉编译Android的x265库
    Chapter9 : De Novo Molecular Design with Chemical Language Models
    第十七节 huggingface的trainner的断点续训的Demo(resume)
  • 原文地址:https://blog.csdn.net/Ango_/article/details/127790921