• 对paddleOCR中的字符识别模型转ONNX


    本文介绍如何将 PaddleOCR 中的字符识别模型转换为 ONNX 格式,并用 onnxruntime 对转换结果进行测试。

    转换代码:

    1. import os
    2. import sys
    3. import yaml
    4. import numpy as np
    5. import cv2
    6. import argparse
    7. import paddle
    8. from paddle import nn
    9. from argparse import ArgumentParser, RawDescriptionHelpFormatter
    10. import paddle.distributed as dist
    11. from ppocr.postprocess import build_post_process
    12. from ppocr.utils.save_load import init_model
    13. from ppocr.modeling.architectures import build_model
    class AttrDict(dict):
        """Single-level attribute dict (NOT recursive).

        Top-level keys can also be read as attributes, i.e. ``d.key`` returns
        ``d['key']``. Nested dicts stay plain dicts.
        """

        def __init__(self, **kwargs):
            super(AttrDict, self).__init__()
            super(AttrDict, self).update(kwargs)

        def __getattr__(self, key):
            # __getattr__ is only consulted after normal attribute lookup has
            # failed, so regular dict methods are never shadowed by keys.
            if key in self:
                return self[key]
            raise AttributeError("object has no attribute '{}'".format(key))
    # Module-wide configuration object; merge_config() writes into it.
    global_config = AttrDict()
    # Baseline values that load_config() merges before the YAML file.
    default_config = {'Global': {'debug': False}}
    class ArgsParser(ArgumentParser):
        """Command-line parser for the conversion script.

        Options:
            -c / --config: path of the YAML configuration file
                (default ``./configs/ch_PP-OCRv2_rec.yml``).
            -o / --opt: one or more ``key=value`` overrides; every value is
                parsed as YAML.
        """

        def __init__(self):
            super(ArgsParser, self).__init__(
                formatter_class=RawDescriptionHelpFormatter)
            self.add_argument(
                "-c", "--config", default='./configs/ch_PP-OCRv2_rec.yml',
                help="configuration file to use")
            self.add_argument(
                "-o", "--opt", nargs='+', help="set configuration options")

        def parse_args(self, argv=None):
            """Parse ``argv`` and replace ``args.opt`` by a dict of overrides."""
            args = super(ArgsParser, self).parse_args(argv)
            assert args.config is not None, \
                "Please specify --config=configure_file_path."
            args.opt = self._parse_opt(args.opt)
            return args

        def _parse_opt(self, opts):
            """Turn a list of ``key=value`` strings into a dict.

            Returns an empty dict when ``opts`` is None or empty.
            """
            config = {}
            if not opts:
                return config
            for s in opts:
                # Split on the first '=' only, so that values which themselves
                # contain '=' (paths, URLs, ...) no longer raise ValueError.
                k, v = s.strip().split('=', 1)
                # NOTE(security): yaml.Loader can construct arbitrary Python
                # objects. Acceptable for options typed by the operator; use
                # yaml.safe_load if the options can come from untrusted input.
                config[k] = yaml.load(v, Loader=yaml.Loader)
            return config
    def merge_config(config):
        """
        Merge config into global config.

        Plain keys are merged at the top level: a dict value updates an
        existing entry in place, any other value replaces it. Dotted keys
        such as ``"Global.use_gpu"`` address a nested entry; the first
        component must already exist in the global config.

        Args:
            config (dict): Config to be merged.
        Returns: global config
        """
        for key, value in config.items():
            if "." not in key:
                if isinstance(value, dict) and key in global_config:
                    global_config[key].update(value)
                else:
                    global_config[key] = value
            else:
                sub_keys = key.split('.')
                assert (
                    sub_keys[0] in global_config
                ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                    global_config.keys(), sub_keys[0])
                # Walk down to the parent of the last component, then assign.
                cur = global_config[sub_keys[0]]
                for sub_key in sub_keys[1:-1]:
                    cur = cur[sub_key]
                cur[sub_keys[-1]] = value
        # The docstring always promised this return value; callers that
        # ignore it are unaffected.
        return global_config
    def load_config(file_path):
        """
        Load config from yml/yaml file.

        The defaults in ``default_config`` are merged first, then the content
        of the file.

        Args:
            file_path (str): Path of the config file to be loaded.
        Returns: global config
        """
        merge_config(default_config)
        _, ext = os.path.splitext(file_path)
        assert ext in ['.yml', '.yaml'], "only support yaml files for now"
        # Close the file deterministically instead of leaking the handle.
        with open(file_path, 'rb') as f:
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects; only load configuration files you trust.
            loaded = yaml.load(f, Loader=yaml.Loader)
        merge_config(loaded)
        return global_config
    def check_device(use_gpu, use_xpu=False):
        """
        Log error and exit when set use_gpu=true in paddlepaddle
        cpu version.

        Args:
            use_gpu (bool): Whether the config requests GPU execution.
            use_xpu (bool): Whether the config requests XPU execution.

        Exits the process with status 1 when the requested device is not
        supported by the installed paddle build. Any other error raised
        while probing paddle is reported and then ignored (best effort).
        """
        err = "Config {} cannot be set as true while your paddle " \
              "is not compiled with {} ! \nPlease try: \n" \
              "\t1. Install paddlepaddle to run model on {} \n" \
              "\t2. Set {} as false in config file to run " \
              "model on CPU"
        try:
            if use_gpu and use_xpu:
                print("use_xpu and use_gpu can not both be ture.")
            if use_gpu and not paddle.is_compiled_with_cuda():
                print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
                sys.exit(1)
            if use_xpu and not paddle.device.is_compiled_with_xpu():
                print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
                sys.exit(1)
        except Exception as e:
            # sys.exit() raises SystemExit, which is not an Exception and so
            # still terminates the process. Everything else stays non-fatal,
            # but is no longer swallowed silently.
            print("check_device: device check skipped ({})".format(e))
    def getArgs(is_train=False):
        """Parse the command line, load the config and select the device.

        Args:
            is_train (bool): Not used by this function; kept so existing
                callers keep working.

        Returns:
            tuple: ``(config, device)`` where ``config`` is the merged global
            config and ``device`` is the value returned by
            ``paddle.set_device``.
        """
        flags = ArgsParser().parse_args()
        config = load_config(flags.config)
        # Command-line overrides take precedence over the file content.
        merge_config(flags.opt)

        # check if set use_gpu=True in paddlepaddle cpu version
        use_gpu = config['Global']['use_gpu']
        use_xpu = False

        alg = config['Architecture']['algorithm']
        assert alg in [
            'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
            'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
            'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
            'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
            'Gestalt', 'SLANet', 'RobustScanner'
        ], "unsupported algorithm: {}".format(alg)

        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
        check_device(use_gpu, use_xpu)
        device = paddle.set_device(device)

        config['Global']['distributed'] = dist.get_world_size() != 1
        return config, device
    class CRNN(nn.Layer):
        """Export wrapper around a PaddleOCR recognition model.

        Builds the network described by ``config['Architecture']``, loads its
        weights through ``init_model`` and, in ``forward``, reduces the model
        output along axis 2 to the arg-max index and the maximum value.

        Args:
            config (dict): Loaded configuration; the 'Architecture',
                'PostProcess' and 'Global' sections are read, and the head
                ``out_channels`` entries of 'Architecture' are written.
            device: Not used by this class; kept for call-site compatibility.
        """

        def __init__(self, config, device):
            super(CRNN, self).__init__()
            # Preprocessing constants. forward() does not apply them at the
            # moment (the in-graph normalisation is disabled), so the caller
            # has to normalise the input itself.
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
            self.mean = paddle.to_tensor(mean).reshape([1, 3, 1, 1])
            self.std = paddle.to_tensor(std).reshape([1, 3, 1, 1])

            self.config = config

            # build post process
            self.post_process_class = build_post_process(config['PostProcess'],
                                                         config['Global'])

            # build model: derive the head output size from the character set
            if hasattr(self.post_process_class, 'character'):
                char_num = len(getattr(self.post_process_class, 'character'))
                arch = self.config['Architecture']
                if arch["algorithm"] in ["Distillation"]:  # distillation model
                    for key in arch["Models"]:
                        head = arch['Models'][key]['Head']
                        if head['name'] == 'MultiHead':  # for multi head
                            out_channels_list = {}
                            # NOTE(review): char_num is decremented once per
                            # MultiHead sub-model, as in the original code;
                            # confirm this is intended for several sub-models.
                            if self.config['PostProcess'][
                                    'name'] == 'DistillationSARLabelDecode':
                                char_num = char_num - 2
                            out_channels_list['CTCLabelDecode'] = char_num
                            out_channels_list['SARLabelDecode'] = char_num + 2
                            head['out_channels_list'] = out_channels_list
                        else:
                            head['out_channels'] = char_num
                elif arch['Head']['name'] == 'MultiHead':  # for multi head
                    out_channels_list = {}
                    if self.config['PostProcess']['name'] == 'SARLabelDecode':
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    arch['Head']['out_channels_list'] = out_channels_list
                else:  # base rec model
                    arch["Head"]['out_channels'] = char_num

            # build the network and load its weights
            self.model = build_model(config['Architecture'])
            init_model(self.config, self.model)
            self.model.eval()

        def forward(self, x):
            """Run the model on ``x`` and return ``(preds_idx, preds_prob)``.

            ``x`` is passed to the model unchanged. Both outputs are cast to
            float32: the arg-max index and the maximum value along axis 2 of
            the model output.
            """
            model_out = self.model(x)
            preds_idx = model_out.argmax(axis=2, name='class').astype('float32')
            preds_prob = model_out.max(axis=2, name='score').astype('float32')
            return preds_idx, preds_prob
    # Switches for the export step below.
    EXPORT_ONNX = True   # write the ONNX file
    DYNAMIC = False      # export with a dynamic batch dimension

    if __name__ == '__main__':
        config, device = getArgs()
        model_crnn = CRNN(config, device=device)

        # Build a sample input of shape (1, 3, 32, 320) from an image file.
        image_path = "1.jpg"
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail here with a clear message.
            raise FileNotFoundError(
                "cannot read sample image: {}".format(image_path))
        img = cv2.resize(img, (320, 32))
        print('input data:', img.shape)
        img = img.astype(np.float32)
        img = img.transpose((2, 0, 1)) / 255
        input_data = img[np.newaxis, :]
        print('input data:', input_data.shape)
        x = paddle.to_tensor(input_data)
        print('input data:', x.shape)

        output_idx, output_prob = model_crnn(x)
        print('output_idx: ', output_idx)
        print('output_prob: ', output_prob)

        input_spec = paddle.static.InputSpec.from_tensor(x, name='input')

        onnx_save_path = "./export_onnx"
        if EXPORT_ONNX:
            os.makedirs(onnx_save_path, exist_ok=True)
            onnx_model_name = onnx_save_path + "/char_recognize_20230526_v1"
            if DYNAMIC:
                # The model is fed channel-first data (see the sample input
                # above), so the dynamic spec must be NCHW as well; the
                # previous NHWC shape [None, 32, 320, 3] did not match.
                input_spec = paddle.static.InputSpec(
                    shape=[None, 3, 32, 320], dtype='float32', name='input')
            # Export the ONNX model
            paddle.onnx.export(model_crnn, onnx_model_name, input_spec=[input_spec], opset_version=11,
                               enable_onnx_checker=True, output_spec=[output_idx, output_prob])

    转换后的网络结构可以使用 Netron 工具进行可视化。

    下面是可视化后网络起始部分和末尾部分的结构:

    测试ONNX的代码:

    1. '''
    2. 测试转出的onnx模型
    3. '''
    4. import cv2
    5. import numpy as np
    6. import torch
    7. import onnxruntime as rt
    8. import math
    9. import os
    class TestOnnx:
        """Run the exported ONNX model with onnxruntime and decode its output.

        Args:
            onnx_file (str): Path of the ONNX model file.
            character_dict_path (str): UTF-8 text file with one character per
                line.
            use_space_char (bool): Append a space character to the character
                table when True.
        """

        def __init__(self, onnx_file, character_dict_path, use_space_char=True):
            self.sess = rt.InferenceSession(onnx_file)
            # Names of the graph inputs and outputs
            self.input_names = [node.name for node in self.sess.get_inputs()]
            self.output_names = [node.name for node in self.sess.get_outputs()]
            # Index 0 is the "blank" placeholder, which decode() ignores.
            self.character = ["blank"]
            with open(character_dict_path, "rb") as fin:
                for line in fin:
                    self.character.append(line.decode('utf-8').strip("\r\n"))
            if use_space_char:
                self.character.append(" ")

        def resize_norm_img(self, img, image_shape=(3, 32, 320)):
            """Resize ``img`` to the target height, normalise and right-pad it.

            The aspect ratio is kept unless the resized width would exceed
            the target width. Pixel values are mapped to [-1, 1] through
            ``(v / 255 - 0.5) / 0.5``; the padding area is filled with 0.

            Args:
                img: Image array indexed as (height, width[, channel]).
                image_shape: Target ``(channels, height, width)``.

            Returns:
                float32 array of shape ``image_shape``.
            """
            imgC, imgH, imgW = image_shape
            h, w = img.shape[0], img.shape[1]
            ratio = w / float(h)
            resized_w = min(imgW, int(math.ceil(imgH * ratio)))
            resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
            if imgC == 1:
                resized_image = resized_image / 255
                resized_image = resized_image[np.newaxis, :]
            else:
                resized_image = resized_image.transpose((2, 0, 1)) / 255
            resized_image -= 0.5
            resized_image /= 0.5
            padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
            padding_im[:, :, 0:resized_w] = resized_image
            return padding_im

        def process(self, input_names, image):
            """Build the feed dict: every input name is mapped to ``image``."""
            return {input_name: image for input_name in input_names}

        def get_ignored_tokens(self):
            """Return the indices that decode() drops (the blank at 0)."""
            return [0]

        def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
            """ convert text-index into text-label.

            Args:
                text_index: Per-sample sequences of character indices
                    (array-like of numpy arrays).
                text_prob: Optional per-sample scores aligned with
                    ``text_index``.
                is_remove_duplicate (bool): Drop an index that equals its
                    predecessor.

            Returns:
                list of ``(text, mean_score)`` tuples, one per sample.
            """
            result_list = []
            ignored_tokens = self.get_ignored_tokens()
            for batch_idx in range(len(text_index)):
                indices = text_index[batch_idx]
                selection = np.ones(len(indices), dtype=bool)
                if is_remove_duplicate:
                    selection[1:] = indices[1:] != indices[:-1]
                for ignored_token in ignored_tokens:
                    selection &= indices != ignored_token

                char_list = [
                    self.character[int(text_id)].replace('\n', '')
                    for text_id in indices[selection]
                ]
                if text_prob is not None:
                    conf_list = text_prob[batch_idx][selection]
                else:
                    conf_list = [1] * len(selection)
                if len(conf_list) == 0:
                    # Nothing was kept: report a score of 0 instead of NaN.
                    conf_list = [0]

                text = ''.join(char_list)
                result_list.append((text, np.mean(conf_list).tolist()))
            return result_list

        def test(self, image_path):
            """Recognise the text in one image file and print it.

            Returns:
                The list produced by ``decode``, or None when the file cannot
                be read as an image.
            """
            img_onnx = cv2.imread(image_path)
            if img_onnx is None:
                # cv2.imread returns None for unreadable / non-image files.
                print(image_path, 'cannot be read as an image, skipped')
                return None
            img_onnx = self.resize_norm_img(img_onnx)
            # Add the batch dimension; no torch round trip is needed to get a
            # float32 numpy array for onnxruntime.
            onnx_indata = np.asarray(img_onnx[np.newaxis, :, :, :], dtype=np.float32)
            print('image shape: ', onnx_indata.shape)

            feed_dict = self.process(self.input_names, onnx_indata)
            output_onnx = self.sess.run(self.output_names, feed_dict)
            preds_idx = output_onnx[0]
            preds_prob = output_onnx[1]

            post_result = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
            # decode() always returns a list, so only the list case is handled.
            if post_result and len(post_result[0]) >= 2:
                print(image_path, post_result[0][0])
            return post_result
    if __name__ == '__main__':
        image_dir = "./sample/img"
        onnx_file = './export_onnx/char_recognize_20230526_v1.onnx'
        character_dict_path = './all_label_num_20230517.txt'

        testobj = TestOnnx(onnx_file, character_dict_path)
        # Sort the listing so that the output order is reproducible.
        for file_name in sorted(os.listdir(image_dir)):
            testobj.test(os.path.join(image_dir, file_name))

    模型转换结束。 

  • 相关阅读:
    Mybatis--动态SQL
    Gartner“客户之声”最高分,用户体验成中国数据库一大突破口
    标准IDOC同步物料
    21-SpringSecurity
    深度学习框架【MxNet】的安装
    蓝桥杯每日一题0223.10.23
    淘宝数据采集接口
    【物联网】802.15.4简介
    FFmpeg入门详解之103:FFmpeg Nginx VLC打造M3U8直播点播
    【软件工程】【第一章概述】d2
  • 原文地址:https://blog.csdn.net/qq_22764813/article/details/133787584