码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • EasyOCR 识别模型训练


    0. 开始之前

    EasyOCR 中使用的神经网络模型在每个阶段会不同基于开源的项目:数据集整合、数据集训练、模型使用。分别对应三种不同的框架。
    训练数据生成:
    GitHub - Belval/TextRecognitionDataGenerator: A synthetic data generator for text recognition
    训练数据转换:
    GitHub - DaveLogs/TRDG2DTRB: Convert TextRecognitionDataGenerator's result data to deep-text-recognition-benchmark's input data.
    训练和部署模型:
    https://github.com/clovaai/deep-text-recognition-benchmark
    使用用户学习模型:
    GitHub - JaidedAI/EasyOCR: Ready-to-use OCR with 80+ supported languages and all popular writing scripts including Latin, Chinese, Arabic, Devanagari, Cyrillic and etc.

    1. 创建训练数据

    训练数据生成步骤将使用一个名为 TextRecognitionDataGenerator 的开源项目。
    参考: https://blog.csdn.net/leiwuhen92/article/details/126419244  文本识别数据生成器-TextRecognitionDataGenerator
    trdg -c 2000000 -w 5 -f 64 -k 5生成训练数据2000000条:
    下一步是进行一个简单的数据转换过程,因为本文中使用TextRecognitionDataGenerator项目生成的学习数据不是deep-text-recognition-benchmark项目学习 所需的数据结构。

    2. 学习数据转换

    使用TextRecognitionDataGenerator项目生成的学习数据不是deep-text-recognition-benchmark项目学习 所需的数据结构。需要进行转换
    跳转中...

    2.1、项目安装

    $ git clone https://github.com/DaveLogs/TRDG2DTRB.git

    2.2、数据转换

    输入数据结构:
    执行命令进行转换:
    python3 convert.py  --input_path /home/ocr/  --output_path ./output

    输出:

    生成的数据由图像文件列表和 gt.txt 文件组成,其中存储了每个图像文件的标签。
    输出数据结构:
         
    原始图片的命名是有要求的:图片内容_index编号.后缀
    像4051.jpg这种格式的经过转换后得到的gt.txt如下,不是我们想要的
    相关代码逻辑如下:

    3. 训练模型

    需要借助deep-text-recognition-benchmark的开源项目。

    3.1、项目安装

    1. # 下载源代码
    2. $ git clone https://github.com/clovaai/deep-text-recognition-benchmark.git
    3. # 搭建开发环境
    4. $ pip3 install torch torchvision
    5. $ pip3 install lmdb pillow nltk natsort
    6. $ pip3 install fire

    3.2、准备阶段

    准备用于神经网络训练的训练数据和微调学习所需的预训练模型。

    3.2.1、训练数据

    3.2.2、将训练数据转换为lmdb格式

    在deep-text-recognition-benchmark项目中使用以下命令语法将其转换为lmdb格式以供实际学习时使用。
    1. # deep-text-recognition-benchmark 从项目根运行
    2. (venv) $ python3 create_lmdb_dataset.py \
    3.         --inputPath /home/TRDG2DTRB/output/ \
    4.         --gtFile /home/TRDG2DTRB/output/gt.txt \
    5.         --outputPath result/

    至此,准备训练数据的一系列过程就结束了。
    为了提高学习性能,将训练和验证的训练数据分别分为MJ和ST来构建数据,训练时设置batch_ratio来学习MJ和ST数据以适当的比例。

    3.2.3、准备预训练模型

    下载学习模型 跳转中...icon-default.png?t=N7T8https://link.zhihu.com/?target=https%3A//github.com/clovaai/deep-text-recognition-benchmark%23run-demo-with-pretrained-model 下载与实际 EasyOCR 中使用的基本模型具有相同网络结构(' None-VGG-BiLSTM-CTC ')的预训练模型。

    3.2.4、项目和预模型正常运行的确认

    让我们使用以下语法测试deep-text-recognition-benchmark项目是否与下载的模型正常工作。
    # demo.py中可查看参数及其定义
     
    1. python3 demo.py \
    2. --Transformation None \
    3. --FeatureExtraction VGG \
    4. --SequenceModeling BiLSTM \
    5. --Prediction CTC \
    6. --image_folder demo_image/ \
    7. --saved_model None-VGG-BiLSTM-CTC.pth

    3.3、训练模型

    训练数据和学习所需的预训练模型(None-VGG-BiLSTM-CTC.pth )都准备好了,就可以使用deep-text-recognition-benchmark项目提供的以下命令语法开始学习。
    # train.py中查看参数及其定义
     
    1. python3 train.py --train_data lmdb/training \
    2. --valid_data lmdb/validation \
    3. --select_data MJ-ST \
    4. --batch_ratio 0.5-0.5 \
    5. --Transformation None \
    6. --FeatureExtraction VGG \
    7. --SequenceModeling BiLSTM \
    8. --Prediction CTC \
    9. --saved_model None-VGG-BiLSTM-CTC.pth \
    10. --num_iter 2000 \
    11. --valInterval 20 \
    12. --FT

    上述命令语法的简要说明如下。

    • --train_data : 训练数据中训练的数据路径
    • --valid_data : 训练数据之间验证的数据路径
    • --select_data : 选择训练数据(默认为MJ-ST,即MJ和ST作为训练数据)
    • --batch_ratio:为批次中的每个选定数据分配比率
    • --Transformation:选择要使用的转换模块。['无','TPS']
    • --FeatureExtraction : 选择要使用的 FeatureExtraction 模块,['RCNN'、'ResNet'、'VGG']
    • --SequenceModeling:选择要使用的 SequenceModeling 模块。['无','BiLSTM']
    • --Prediction:选择要使用的预测模块。['Attn', 'CTC']
    • --saved_model : 用于微调学习的预训练模型的存储位置
    • --num_iter: 训练迭代次数,默认300000
    • --valInterval:每次检验之间的时间间隔,默认2000
    • --FT : 是否学习微调
    • --lr:学习率,对于 Adadelta,默认 = 1.0
    • --batch_max_length:最大标签长度,默认值25
    • --imgH:输入图像的高度,默认32      # 后面的识别配置模块nvbc.yaml文件会用到
    • --input_channel:特征提取器的输入通道数,默认1
    • --output_channel:特征提取器的输出通道数,默认512
    • --hidden_size:LSTM 隐藏状态的大小,默认256
    报错:提示训练模型需在CUDA设备上运行
    但若想在CPU上运行,可根据提示修改为如下:
    再次运行,得到:
    等待一段时间,直至出现“end the training”字符,训练结束。
    学习结果保存在当前目录下的/saved_models 文件夹中:
    存储的学习结果信息如下:
    • best_accuracy.pth / best_norm_ED.pth:在经过训练的模型文件中具有特定性能指数的选定模型;
    • log_dataset.txt:用于训练的数据集信息;
    • log_train.txt:训练正在进行时的日志(与上面终端中显示的相同)
    • opt.txt:执行学习命令语法时设置的学习选项信息

    3.4、测试模型

    让我们使用训练好的模型best_accuracy.pth来检查训练是否正确完成。
    同样,上面使用的语法按原样使用。但是,要使用的模型被指定为新学习的模型(./saved_models/None-VGG-BiLSTM-CTC-Seed1111/best_accuracy.pth)。
    1. # 测试项目中包含的演示图像
    2. python3 demo.py \
    3. --Transformation None \
    4. --FeatureExtraction VGG \
    5. --SequenceModeling BiLSTM \
    6. --Prediction CTC \
    7. --image_folder demo_image/ \
    8. --saved_model ./saved_models/None-VGG-BiLSTM-CTC-Seed1111/best_accuracy.pth

    4. 使用模型

    前提:环境上已经安装easyocr。

    4.1、用户模型环境配置

    用户学习模型、模块和配置文件的名称必须统一,这里假设用户模型文件的名称设置为“nvbc”。
    1. 复制3.3节生成的用户模型./saved_models/None-VGG-BiLSTM-CTC-Seed1111/best_accuracy.pth到/root/.EasyOCR/model/,改名为nvbc.pth;
    2. 在/root/.EasyOCR/user_network/下建立用户识别模型网络模块nvbc.py,用户识别配置模块nvbc.yaml。

    4.1.1、创建nvbc.yaml

    该配置文件包含用于训练学习模型的参数和使用EasyOCR模块所需的参数信息。
    # 值要与deep-text-recognition-benchmark/train.py中的值保持一致,因为是根据train.py训练出来的模型
     
    1. network_params:
    2.   input_channel: 1
    3.   output_channel: 512
    4.   hidden_size: 256
    5. imgH: 32
    6. lang_list:
    7.          - 'nvbc'   # 语言代码   对应与/usr/local/lib/python3.6/dist-packages/easyocr/character/nvbc_char.txt,没有则创建
    8. character_list: 0123456789abcdefghijklmnopqrstuvwxyz   # 学习数据类

    4.1.2、创建nvbc.py

    定义用户识别模型网络结构的模块文件,由于我们使用了EasyOCR模块中使用的'TPS-ResNet-BiLSTM-Attn'结构,所以可以使用EasyOCR项目提供的文件进行如下配置:
    1. import torch.nn as nn
    2. class Model(nn.Module):
    3.     def __init__(self, input_channel, output_channel, hidden_size, num_class):
    4.         super(Model, self).__init__()
    5.         """ FeatureExtraction """
    6.         self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
    7.         self.FeatureExtraction_output = output_channel
    8.         self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))
    9.         """ Sequence modeling"""
    10.         self.SequenceModeling = nn.Sequential(
    11.             BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
    12.             BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
    13.         self.SequenceModeling_output = hidden_size
    14.         """ Prediction """
    15.         self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)
    16.     def forward(self, input, text):
    17.         """ Feature extraction stage """
    18.         visual_feature = self.FeatureExtraction(input)
    19.         visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
    20.         visual_feature = visual_feature.squeeze(3)
    21.         """ Sequence modeling stage """
    22.         contextual_feature = self.SequenceModeling(visual_feature)
    23.         """ Prediction stage """
    24.         prediction = self.Prediction(contextual_feature.contiguous())
    25.         return prediction
    26. class BidirectionalLSTM(nn.Module):
    27.     def __init__(self, input_size, hidden_size, output_size):
    28.         super(BidirectionalLSTM, self).__init__()
    29.         self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
    30.         self.linear = nn.Linear(hidden_size * 2, output_size)
    31.     def forward(self, input):
    32.         """
    33.         input : visual feature [batch_size x T x input_size]
    34.         output : contextual feature [batch_size x T x output_size]
    35.         """
    36.         try: # multi gpu needs this
    37.             self.rnn.flatten_parameters()
    38.         except: # quantization doesn't work with this
    39.             pass
    40.         recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
    41.         output = self.linear(recurrent)  # batch_size x T x output_size
    42.         return output
    43. class VGG_FeatureExtractor(nn.Module):
    44.     def __init__(self, input_channel, output_channel=256):
    45.         super(VGG_FeatureExtractor, self).__init__()
    46.         self.output_channel = [int(output_channel / 8), int(output_channel / 4),
    47.                                int(output_channel / 2), output_channel]
    48.         self.ConvNet = nn.Sequential(
    49.             nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
    50.             nn.MaxPool2d(2, 2),
    51.             nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
    52.             nn.MaxPool2d(2, 2),
    53.             nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
    54.             nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
    55.             nn.MaxPool2d((2, 1), (2, 1)),
    56.             nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
    57.             nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
    58.             nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
    59.             nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
    60.             nn.MaxPool2d((2, 1), (2, 1)),
    61.             nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True))
    62.     def forward(self, input):
    63.         return self.ConvNet(input)
    作为参考,如果你想通过改变模型的网络结构来学习和使用,deep-text-recognition-benchmark项目的'deep-text-recognition-benchmark/model.py'文件和'deep-text -recognition-benchmark/modules/ 你可以参考'.custom.py'中的文件来配置这个'custom.py'文件。

    4.2、EasyOCR 运行参数

    参考: OCR-easyocr初识_青霄的博客-CSDN博客
    编写如下代码并运行它:testzq.py
    1. from easyocr.easyocr import *
    2. # # GPU 环境
    3. # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
    4. def get_files(path):
    5.     files = [f for f in os.listdir(path) if not f.startswith('.')]  # skip hidden file
    6.     files.sort()
    7.     abspath = os.path.abspath(path)
    8.     file_list = []
    9.     for file in files:
    10.         file_path = os.path.join(abspath, file)
    11.         file_list.append(file_path)
    12.     return file_list, len(file_list)
    13. if __name__ == '__main__':
    14.     # Using custom model
    15.     reader = Reader(['nvbc'], gpu=False,   # 语言存储在/usr/local/lib/python3.6/dist-packages/easyocr/character/nvbc_char.txt
    16.                     model_storage_directory='/root/.EasyOCR/model',  
    17.                     user_network_directory='/root/.EasyOCR/user_network',
    18.                     recog_network='nvbc')
    19.     files, count = get_files(path='/home/deep-text-recognition-benchmark/demo_image/')
    20.     for idx, file in enumerate(files):
    21.         filename = os.path.basename(file)
    22.         result = reader.readtext(file)
    23.         # ./easyocr/utils.py 733 lines
    24.         # result[0]: bbox
    25.         # result[1]: string
    26.         # result[2]: confidence
    27.         for (bbox, string, confidence) in result:
    28.             print("filename: '%s', confidence: %.4f, string: '%s'" % (filename, confidence, string))

    使用用户模型运行: python3 testzq.py,结果如下:

    ​

    错误1:训练数据比较大时,训练模型报错:ValueError: num_samples should be a positive integer value, but got num_samples=0
    ​
    原因是:图片的名称长度大于--batch_max_length的默认值、而且包含的字符不在默认的--character中
    ​

    五、参考

    Easy-OCR笔记整理 - 知乎如果可以在瑞士工作一年的话,我会享受这样的生活。 学习EasyOCR用户模型这次通过EasyOCR提供的API,不是使用OCR功能时使用的基本神经网络模型,而是直接准备和学习用户想要学习的数据,并创建和使用具有所需性能…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/400270506

    【扫盲】RCNN+CTC字符训练识别_哔哩哔哩_bilibili欢迎关注公众号:小鸡炖技术 ,后台回复:“RCNN+CTC”获取本教程素材~~~, 视频播放量 3161、弹幕量 0、点赞数 38、投硬币枚数 34、收藏人数 103、转发人数 5, 视频作者 小鸡炖技术, 作者简介 公众号:小鸡炖技术,相关视频:【陈巍学基因】视频30:CellSearch检测CTC,见微知著的查癌方法——CTC循环肿瘤细胞检测,字符识别,时间序列LSTM深度学习模型代码讲解,1.1Faster RCNN理论合集,6、字符分割,如何读懂PyTorch深度学习代码-第一个深度学习实例-手写字符识别代码解析,【扫盲】DarkNet下YoloV4训练,阿丘科技深度学习AIDI讲解之字符识别,美国铁路CTC调度集中系统 - BNSF铁路官方科普【搬运】icon-default.png?t=N7T8https://www.bilibili.com/video/BV1JA411t7H9 deep-text-recognition-benchmarkicon-default.png?t=N7T8https://link.zhihu.com/?target=https%3A//github.com/clovaai/deep-text-recognition-benchmark

    python - Size mismatch for fc.bias and fc.weight in PyTorch - Stack Overflowicon-default.png?t=N7T8https://stackoverflow.com/questions/53612835/size-mismatch-for-fc-bias-and-fc-weight-in-pytorch

    PyTorch加载模型出现Error(s) in loading state_dict() for Model问题,Unexpected key(s) in state_dict: “...“_行走的笔记的博客-CSDN博客问题:模型在训练过程中可以正常训练,但是测试的时候出现了错误,如下所示:RuntimeError: Error(s) in loading state_dict for ModuleList:Missing key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", Unexpected key(s) in state_dict: "conv1.weight", "bn1.wehttps://blog.csdn.net/qq_45777045/article/details/109481993 Ubuntu 18.04 安装 NVIDIA 显卡驱动 - 知乎我们今天的目标是在 Ubuntu 18.04 上安装 NVIDIA 显卡驱动,请注意,你的显卡一定要是 NVIDIA 的显卡才能按照这篇文章的方法安装。我将给大家介绍三种安装方法,建议使用第一种方法安装。 先来说说带有 NVIDIA 独…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/59618999

  • 相关阅读:
    前端 CSS 经典:SVG 描边动画
    Android Handler/Looper视角看UI线程的原理
    遥感和随机森林核心思想python
    【C++】泛型编程 ⑦ ( 类模板常用用法 | 类模板声明 | 类模板调用 | 类模板作为函数参数 )
    软件明明通过了各种级别的测试,交付给用户仍会出现问题?
    vim相关命令讲解!
    B. Bin Packing Problem(线段树+multiset)
    为什么创建百科词条?百科营销的作用
    Spring中@Validated和@Valid区别是什么
    多功能神器,强劲升级,太极2.x你值得拥有!
  • 原文地址:https://blog.csdn.net/leiwuhen92/article/details/126419345
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号