• Paddle-OCR根据垂直类场景自定义数据微调PP-OCRv4模型


    Paddle-OCR根据垂直类场景自定义数据微调PP-OCRv4模型

    1 文本检测模型微调

    数据准备:

    • 加入少量真实数据(检测任务>=500张, 识别任务>=5000张),会大幅提升垂类场景的检测与识别效果
    • 在模型微调时,加入真实通用场景数据,可以进一步提升模型精度与泛化性能
    • 在图像检测任务中,增大图像的预测尺度,能够进一步提升较小文字区域的检测效果
    • 在模型微调时,需要适当调整超参数(学习率,batch size最为重要),以获得更优的微调效果。
    • 数据标注:单行文本标注格式,建议标注的检测框与实际语义内容一致。

    1-1 数据准备

    训练集&校验集

    PaddleOCR 中的文本检测算法支持的标注文件格式如下,中间用"\t"分隔:

    " 图像文件名                    json.dumps编码的图像标注信息"
    ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
    
    • 1
    • 2

    json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 points 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 transcription 表示当前文本框的文字,当其内容为“###”时,表示该文本框无效,在训练时会跳过。

    公开数据集
    数据集名称图片下载地址PaddleOCR 标注下载地址
    ICDAR 2015https://rrc.cvc.uab.es/?ch=4&com=downloadstrain / test
    ctw1500https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip图片下载地址中已包含
    total texthttps://paddleocr.bj.bcebos.com/dataset/total_text.tar图片下载地址中已包含
    td trhttps://paddleocr.bj.bcebos.com/dataset/TD_TR.tar图片下载地址中已包含

    1-2 下载预训练模型

    ch_PP-OCRv4_det_train

    1-3 参数配置

    配置文件: configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml

    Global:
      debug: false
      use_gpu: true
      epoch_num: &epoch_num 500
      log_smooth_window: 20
      print_batch_step: 100
      save_model_dir: ./output/ch_PP-OCRv4
      save_epoch_step: 10
      eval_batch_step:
      - 0
      - 1500
      cal_metric_during_train: false
      checkpoints:
      pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
      save_inference_dir: null
      use_visualdl: false
      infer_img: doc/imgs_en/img_10.jpg
      save_res_path: ./checkpoints/det_db/predicts_db.txt
      distributed: true
    
    Architecture:
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: PPLCNetV3
        scale: 0.75
        det: True
      Neck:
        name: RSEFPN
        out_channels: 96
        shortcut: True
      Head:
        name: DBHead
        k: 50
    
    Loss:
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
    
    Optimizer:
      name: Adam
      beta1: 0.9
      beta2: 0.999
      lr:
        name: Cosine
        learning_rate: 0.001 #(8*8c)
        warmup_epoch: 2
      regularizer:
        name: L2
        factor: 5.0e-05
    
    PostProcess:
      name: DBPostProcess
      thresh: 0.3
      box_thresh: 0.6
      max_candidates: 1000
      unclip_ratio: 1.5
    
    Metric:
      name: DetMetric
      main_indicator: hmean
    
    Train:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data/icdar2015/text_localization/
        label_file_list:
          - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
        ratio_list: [1.0]
        transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null
        - CopyPaste: null
        - IaaAugment:
            augmenter_args:
            - type: Fliplr
              args:
                p: 0.5
            - type: Affine
              args:
                rotate:
                - -10
                - 10
            - type: Resize
              args:
                size:
                - 0.5
                - 3
        - EastRandomCropData:
            size:
            - 640
            - 640
            max_tries: 50
            keep_ratio: true
        - MakeBorderMap:
            shrink_ratio: 0.4
            thresh_min: 0.3
            thresh_max: 0.7
            total_epoch: *epoch_num
        - MakeShrinkMap:
            shrink_ratio: 0.4
            min_text_size: 8
            total_epoch: *epoch_num
        - NormalizeImage:
            scale: 1./255.
            mean:
            - 0.485
            - 0.456
            - 0.406
            std:
            - 0.229
            - 0.224
            - 0.225
            order: hwc
        - ToCHWImage: null
        - KeepKeys:
            keep_keys:
            - image
            - threshold_map
            - threshold_mask
            - shrink_map
            - shrink_mask
      loader:
        shuffle: true
        drop_last: false
        batch_size_per_card: 8
        num_workers: 8
    
    Eval:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data/icdar2015/text_localization/
        label_file_list:
          - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
        transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null
        - DetResizeForTest:
        - NormalizeImage:
            scale: 1./255.
            mean:
            - 0.485
            - 0.456
            - 0.406
            std:
            - 0.229
            - 0.224
            - 0.225
            order: hwc
        - ToCHWImage: null
        - KeepKeys:
            keep_keys:
            - image
            - shape
            - polys
            - ignore_tags
      loader:
        shuffle: false
        drop_last: false
        batch_size_per_card: 1
        num_workers: 2
    profiler_options: null
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171

    1-4 训练

    python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml \
         -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
    
    • 1
    • 2

    1-5 评估

    python tools/eval.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
    
    • 1

    1-6 推理

    python tools/infer_det.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"
    
    • 1

    1-7 导出

    python3 tools/export_model.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model="./output/det_db/best_accuracy" Global.save_inference_dir="./output/det_db_inference/"
    
    • 1

    2 文本识别模型微调

    2-1 数据准备

    训练集&校验集

    建议将训练图片放入同一个文件夹,并用一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:

    注意: txt文件中默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错。

    " 图像文件名                 图像标注信息 "
    
    train_data/rec/train/word_001.jpg   简单可依赖
    train_data/rec/train/word_002.jpg   用科技让复杂的世界更简单
    ...
    
    • 1
    • 2
    • 3
    • 4
    • 5

    最终训练集应有如下文件结构:

    |-train_data
      |-rec
        |- rec_gt_train.txt
        |- train
            |- word_001.png
            |- word_002.jpg
            |- word_003.jpg
            | ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    除上述单张图像为一行格式之外,PaddleOCR也支持对离线增广后的数据进行训练,为了防止相同样本在同一个batch中被多次采样,我们可以将相同标签对应的图片路径写在一行中,以列表的形式给出,在训练中,PaddleOCR会随机选择列表中的一张图片进行训练。对应地,标注文件的格式如下。

    ["11.jpg", "12.jpg"]   简单可依赖
    ["21.jpg", "22.jpg", "23.jpg"]   用科技让复杂的世界更简单
    3.jpg   ocr
    
    • 1
    • 2
    • 3

    上述示例标注文件中,"11.jpg"和"12.jpg"的标签相同,都是简单可依赖,在训练的时候,对于该行标注,会随机选择其中的一张图片进行训练。

    如果有通用真实场景数据加进来,建议每个epoch中,垂类场景数据与真实场景的数据量保持在1:1左右。

    比如:您自己的垂类场景识别数据量为1W,数据标签文件为vertical.txt,收集到的通用场景识别数据量为10W,数据标签文件为general.txt

    那么,可以设置label_file_listratio_list参数如下所示。每个epoch中,vertical.txt中会进行全采样(采样比例为1.0),包含1W条数据;general.txt中会按照0.1的采样比例进行采样,包含10W*0.1=1W条数据,最终二者的比例为1:1

    Train:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data/
        label_file_list:
        - vertical.txt
        - general.txt
        ratio_list: [1.0, 0.1]
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    字典

    需要提供一个自定义字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。

    因此字典需要包含所有希望被正确识别的字符,{word_dict_name}.txt需要写成如下格式,并以 utf-8 编码格式保存:

    l
    d
    a
    d
    r
    n
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1]

    • 内置字典

      PaddleOCR内置了一部分字典,可以按需使用。

      ppocr/utils/ppocr_keys_v1.txt 是一个包含6623个字符的中文字典

      ppocr/utils/ic15_dict.txt 是一个包含36个字符的英文字典

      ppocr/utils/en_dict.txt 是一个包含96个字符的英文字典

    公开数据集
    数据集名称图片下载地址PaddleOCR 标注下载地址
    en benchmark(MJ, SJ, IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE.)DTRBLMDB格式,可直接用lmdb_dataset.py加载
    ICDAR 2015http://rrc.cvc.uab.es/?ch=4&com=downloadstrain/ test
    多语言数据集百度网盘 提取码:frgi google drive图片下载地址中已包含

    2-2 下载预训练模型

    ch_PP-OCRv4_rec_train

    2-3 参数配置

    配置文件:configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml

    Global:
      debug: false
      use_gpu: true
      epoch_num: 200
      log_smooth_window: 20
      print_batch_step: 10
      save_model_dir: ./output/rec_ppocr_v4
      save_epoch_step: 10
      eval_batch_step: [0, 2000]
      cal_metric_during_train: true
      pretrained_model:
      checkpoints:
      save_inference_dir:
      use_visualdl: false
      infer_img: doc/imgs_words/ch/word_1.jpg
      character_dict_path: ppocr/utils/ppocr_keys_v1.txt
      max_text_length: &max_text_length 25
      infer_mode: false
      use_space_char: true
      distributed: true
      save_res_path: ./output/rec/predicts_ppocrv3.txt
    
    
    Optimizer:
      name: Adam
      beta1: 0.9
      beta2: 0.999
      lr:
        name: Cosine
        learning_rate: 0.001
        warmup_epoch: 5
      regularizer:
        name: L2
        factor: 3.0e-05
    
    
    Architecture:
      model_type: rec
      algorithm: SVTR_LCNet
      Transform:
      Backbone:
        name: PPLCNetV3
        scale: 0.95
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 120
                depth: 2
                hidden_dims: 120
                kernel_size: [1, 3]
                use_guide: True
              Head:
                fc_decay: 0.00001
          - NRTRHead:
              nrtr_dim: 384
              max_text_length: *max_text_length
    
    Loss:
      name: MultiLoss
      loss_config_list:
        - CTCLoss:
        - NRTRLoss:
    
    PostProcess:  
      name: CTCLabelDecode
    
    Metric:
      name: RecMetric
      main_indicator: acc
    
    Train:
      dataset:
        name: MultiScaleDataSet
        ds_width: false
        data_dir: ./train_data/
        ext_op_transform_idx: 1
        label_file_list:
        - ./train_data/train_list.txt
        transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - RecConAug:
            prob: 0.5
            ext_data_num: 2
            image_shape: [48, 320, 3]
            max_text_length: *max_text_length
        - RecAug:
        - MultiLabelEncode:
            gtc_encode: NRTRLabelEncode
        - KeepKeys:
            keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
      sampler:
        name: MultiScaleSampler
        scales: [[320, 32], [320, 48], [320, 64]]
        first_bs: &bs 192
        fix_bs: false
        divided_factor: [8, 16] # w, h
        is_training: True
      loader:
        shuffle: true
        batch_size_per_card: *bs
        drop_last: true
        num_workers: 8
    Eval:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data
        label_file_list:
        - ./train_data/val_list.txt
        transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - MultiLabelEncode:
            gtc_encode: NRTRLabelEncode
        - RecResizeImg:
            image_shape: [3, 48, 320]
        - KeepKeys:
            keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
      loader:
        shuffle: false
        drop_last: false
        batch_size_per_card: 128
        num_workers: 4
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138

    2-4 训练

    python tools/train.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml \
         -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
    
    • 1
    • 2

    2-5 评估

    python tools/eval.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.checkpoints={path/to/weights}/best_accuracy
    
    • 1

    2-6 推理

    python tools/infer_rec.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
    
    • 1

    2-7 导出

    python tools/export_model.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec.yml -o Global.pretrained_model=./pretrain_models/en_PP-OCRv3_rec_train/best_accuracy  Global.save_inference_dir=./inference/en_PP-OCRv3_rec/
    
    • 1

    3 文本方向分类器微调

    3-1 数据准备

    训练集&校验集

    首先建议将训练图片放入同一个文件夹,并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。

    注意: 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错

    0和180分别表示图片的角度为0度和180度

    " 图像文件名                 图像标注信息 "
    train/cls/train/word_001.jpg   0
    train/cls/train/word_002.jpg   180
    
    • 1
    • 2
    • 3

    最终训练集应有如下文件结构:

    |-train_data
        |-cls
            |- cls_gt_train.txt
            |- train
                |- word_001.png
                |- word_002.jpg
                |- word_003.jpg
                | ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    3-2 下载预训练模型

    ch_ppocr_mobile_v2.0_cls_train

    3-3 参数配置

    将准备好的txt文件和图片文件夹路径分别写入配置文件的 Train/Eval.dataset.label_file_listTrain/Eval.dataset.data_dir 字段下,Train/Eval.dataset.data_dir字段下的路径和文件里记载的图片名构成了图片的绝对路径。

    3-4 训练

    python tools/train.py -c configs/cls/cls_mv3.yml
    
    • 1

    3-5 评估

    python tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
    
    • 1

    3-6 推理

    python tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
    
    • 1
  • 相关阅读:
    我的十年编程路 2016年篇
    丁鹿学堂:从零开始手写promise(二)
    分布式锁的实现【转载】
    Feign的面试
    智慧城市-疫情流调系列2.1-Prompt-UIE信息抽取,解决抽取结果不准的问题
    【分享】xpath的路径表达式
    【游戏编程扯淡精粹】工作两年总结
    Python 全栈系列209 so_pack
    Python 物联网之用于基于 TinyFlux的物联网和分析应用程序的微型时间序列数据库
    嵌入式分享合集91
  • 原文地址:https://blog.csdn.net/shanglianlm/article/details/134423832