• paddle ocr 训练数字识别模型


    选择识别算法

    gitlab
    在这里插入图片描述

    在这里插入图片描述

    修改配置文件

    复制rec_icdar15_train.yml配置文件,预训练模型rec_mv3_none_bilstm_ctc

    更改

    • pretrained_model预训练模型路径
    • character_dict_path字典路径,字典内容
    • data_dir
    • label_file_list
      在这里插入图片描述
    Global:
      use_gpu: true
      epoch_num: 10
      log_smooth_window: 20
      print_batch_step: 10000
      save_model_dir: ./output/rec/number/
      save_epoch_step: 5
      # evaluation is run every 2000 iterations
      eval_batch_step: [0, 3, 6, 9]
      cal_metric_during_train: True
      pretrained_model: pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
      checkpoints:
      save_inference_dir: ./
      use_visualdl: False
      infer_img: ./train_data/NUMBER/9997_448.jpg
      # for data or label process
      character_dict_path: ppocr/utils/number_dict.txt
      max_text_length: 6
      infer_mode: False
      use_space_char: True
      save_res_path: ./output/rec/predicts_number.txt
    
    Optimizer:
      name: Adam
      beta1: 0.9
      beta2: 0.999
      lr:
        learning_rate: 0.0005
      regularizer:
        name: 'L2'
        factor: 0
    
    Architecture:
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 96
      Head:
        name: CTCHead
        fc_decay: 0
    
    Loss:
      name: CTCLoss
    
    PostProcess:
      name: CTCLabelDecode
    
    Metric:
      name: RecMetric
      main_indicator: acc
    
    Train:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data/NUMBER/
        label_file_list: ["./train_data/NUMBER/rec_gt_train.txt"]
        transforms:
          - DecodeImage: # load image
              img_mode: BGR
              channel_first: False
          - CTCLabelEncode: # Class handling label
          - RecResizeImg:
              image_shape: [3, 32, 100]
          - KeepKeys:
              keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
      loader:
        shuffle: True
        batch_size_per_card: 256
        drop_last: True
        num_workers: 12
        use_shared_memory: False
    
    Eval:
      dataset:
        name: SimpleDataSet
        data_dir: ./train_data/NUMBER
        label_file_list: ["./train_data/NUMBER/rec_gt_test.txt"]
        transforms:
          - DecodeImage: # load image
              img_mode: BGR
              channel_first: False
          - CTCLabelEncode: # Class handling label
          - RecResizeImg:
              image_shape: [3, 32, 100]
          - KeepKeys:
              keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
      loader:
        shuffle: False
        drop_last: True
        batch_size_per_card: 256
        num_workers: 4
        use_shared_memory: False
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100

    文件格式

    simple dataset以\t分割
    在这里插入图片描述
    对图像生成标签

    import os
    import cv2
    
    from tqdm import tqdm
    
    img_folder = r'xxx'
    target_img_folder = r'./train_data' 
    img_file_list = os.listdir(img_folder)
    
    label_list = []
    
    
    def cv_show(img):
        '''
        展示图片
        @param img:
        @param name:
        @return:
        '''
        cv2.namedWindow('name', cv2.WINDOW_KEEPRATIO)  # cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO
        cv2.imshow('name', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    
    index = 1
    
    for file in tqdm(img_file_list):
        file_path = os.path.join(img_folder, file)
        start_page_str = str(index)
        if len(start_page_str) == 1:
            start_page_str = '00' + start_page_str
        elif len(start_page_str) == 2:
            start_page_str = '0' + start_page_str
        else:
            ...
        if file.endswith('jpg'):
            label = file.split('_')[-1].split('.')[0]
            new_file_path =  os.path.join(target_img_folder, str(start_page_str) +'_'+label+ '.jpg')
            os.rename(file_path,new_file_path)
    
            with open('./rec_gt_train.txt', 'a+', encoding='utf-8') as f:
                f.write(str(start_page_str) +'_'+label+ '.jpg'+'\t'+label+'\n')
            index += 1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    dataset

    dataset:
        name: SimpleDataSet
        data_dir: ./train_data/NUMBER
        label_file_list: ["./train_data/NUMBER/xxx.txt"]
    
    • 1
    • 2
    • 3
    • 4

    在这里插入图片描述

    训练

    python tools/train.py -c configs/rec/rec_icdar15_number_train.yml

    有3万张图片,使用预训练模型,预训练模型效果:
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

    在这里插入图片描述

    不使用训练模型,精度会差点
    在这里插入图片描述

    python tools/train.py -c configs/rec/rec_r34_vd_none_bilstm_ctc_number.yml

    转成推理模型

    python tools/export_model.py -c configs/rec/rec_icdar15_number_train.yml -o Global.checkpoints=./output/rec/number_mv3/best_accuracy  Global.save_inference_dir=./output/rec_icdar15_number/ 
    
    • 1

    警告:

    The shape of model params head.fc.weight [192, 12] not matched with loaded params head.fc.weight [192, 37]

    因为字典改了

    注意

    • 测试集和训练集size大于批次

    预测

    预测图片:
    在这里插入图片描述

    命令行:

    python ../../PaddleOCR/tools/infer/predict_rec.py --image_dir="./test_data/000_4.jpg" --rec_model_dir="../../PaddleOCR/output/recnumber_mv3_none_bilstm_ctc/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="../../PaddleOCR/ppocr/utils/number_dict.txt"
    
    • 1

    在这里插入图片描述

    原始模型
    paddleocr --image_dir=“./test_data/000_4.jpg”
    在这里插入图片描述

    总结

    使用不同图像尺寸 效果会变差,训练中还需要做数据增强

  • 相关阅读:
    Maven的常用命令
    mysql查询结果拼接树结构(树节点的移动)
    AI的IDE:Cursor配置虚拟python环境(conda)
    Web基础习题
    Ubuntu 22.04 编译 DPDK 19.11 igb_uio 和 kni 报错解决办法
    30:第三章:开发通行证服务:13:开发【更改/完善用户信息,接口】;(使用***BO类承接参数,并使用了参数校验)
    Instagram Shop如何开通?如何销售?最全面攻略
    三星大规模生产3nm芯片?预计明年就能流通各大手机厂商手中
    《Datawhale项目实践系列》发布!
    [附源码]计算机毕业设计springboot汽配管理系统
  • 原文地址:https://blog.csdn.net/weixin_38235865/article/details/127849198