Download: https://gitee.com/paddlepaddle/PaddleOCR
I want to train a Chinese recognition model. This pretrained model is described as having the best generalization performance, so I downloaded it:
https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
Other models: https://gitee.com/paddlepaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/models_list.md
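The config folder holds the yml configuration files.
The pretrained_model folder holds the pretrained model downloaded in the previous step.
The split_rec_label folder holds the dataset.
The output folder holds the trained models.
Creating these folders is optional; it just makes your own files easier to manage. The original yml files are under
PaddleOCR-release-2.6/configs/rec/PP-OCRv3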
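The folder layout above can be created in one step. The following is a minimal sketch: the four folder names come from the list above, while the root name "wjp" is an assumption taken from the paths used later in the yml config, and make_workspace is a hypothetical helper name.

```python
import os
import tempfile

# Folder names from the list above.
SUBDIRS = ["config", "pretrained_model", "split_rec_label", "output"]


def make_workspace(root):
    """Create the working folders under root and return their paths."""
    paths = [os.path.join(root, name) for name in SUBDIRS]
    for path in paths:
        # exist_ok=True makes the call safe to repeat.
        os.makedirs(path, exist_ok=True)
    return paths


if __name__ == "__main__":
    # Demonstrated in a temporary directory so nothing real is touched.
    with tempfile.TemporaryDirectory() as tmp:
        for path in make_workspace(os.path.join(tmp, "wjp")):
            print(path)
```

In practice you would call make_workspace("wjp") from the PaddleOCR root directory instead of using a temporary directory.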
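Reference blog: https://blog.csdn.net/qq_49627063/article/details/119134847
Before training, all images sit in one folder and all label information is in a single txt file, so a script is needed to split them into train, val, and test sets at a ratio of 8:1:1.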
import os
import re
import shutil
import random
import argparse


def split_label(all_label, train_label, val_label, test_label):
    """Shuffle the label lines and write them to train/val/test files at 8:1:1."""
    with open(all_label, 'r', encoding='utf-8') as f:
        # Skip blank lines and make sure every line ends with a newline,
        # so shuffling cannot glue two records together.
        raw_list = [line.rstrip('\n') + '\n' for line in f if line.strip()]
    random.shuffle(raw_list)

    num_train = int(len(raw_list) * 0.8)
    num_val = int(len(raw_list) * 0.1)
    # The test split takes the remainder so that no line is lost to truncation.
    with open(train_label, 'w', encoding='utf-8') as f_train:
        f_train.writelines(raw_list[:num_train])
    with open(val_label, 'w', encoding='utf-8') as f_val:
        f_val.writelines(raw_list[num_train:num_train + num_val])
    with open(test_label, 'w', encoding='utf-8') as f_test:
        f_test.writelines(raw_list[num_train + num_val:])


def move_images(label_file, all_imgs, dst_dir):
    """Move every image listed in label_file from all_imgs to dst_dir."""
    with open(label_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Assumes lines look like "<folder>/<image name>\t<label>", so the
            # second field after splitting on "/" or tab is the file name.
            img_name = re.split("[/\t]", line)[1]
            shutil.move(os.path.join(all_imgs, img_name), dst_dir)


def split_img(all_imgs, train_label, train_imgs, val_label, val_imgs, test_label, test_imgs):
    move_images(train_label, all_imgs, train_imgs)
    move_images(val_label, all_imgs, val_imgs)
    move_images(test_label, all_imgs, test_imgs)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--all_label", default="../paddleocr/PaddleOCR/train_data/cls/cls_gt_train.txt")
    parser.add_argument("--all_imgs_dir", default="../paddleocr/PaddleOCR/train_data/cls/images/")
    parser.add_argument("--train_label", default="../paddleocr/PaddleOCR/train_data/cls/train.txt")
    parser.add_argument("--train_imgs_dir", default="../paddleocr/PaddleOCR/train_data/cls/train/")
    parser.add_argument("--val_label", default="../paddleocr/PaddleOCR/train_data/cls/val.txt")
    parser.add_argument("--val_imgs_dir", default="../paddleocr/PaddleOCR/train_data/cls/val/")
    parser.add_argument("--test_label", default="../paddleocr/PaddleOCR/train_data/cls/test.txt")
    parser.add_argument("--test_imgs_dir", default="../paddleocr/PaddleOCR/train_data/cls/test/")
    return parser.parse_args()


def main(args):
    # Create the three destination folders if they do not exist yet.
    for img_dir in (args.train_imgs_dir, args.val_imgs_dir, args.test_imgs_dir):
        os.makedirs(img_dir, exist_ok=True)
    split_label(args.all_label, args.train_label, args.val_label, args.test_label)
    split_img(args.all_imgs_dir, args.train_label, args.train_imgs_dir, args.val_label, args.val_imgs_dir, args.test_label, args.test_imgs_dir)


if __name__ == "__main__":
    main(get_args())
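One detail of the 8:1:1 split is worth spelling out: if each of the three counts is computed with int(len * ratio), the truncated counts may not add up to the total, and the leftover lines end up in no split. The sketch below shows the arithmetic when the test split simply takes the remainder. The function name split_sizes is hypothetical and only illustrates the counting, not the file handling.

```python
def split_sizes(n, train_ratio=0.8, val_ratio=0.1):
    """Return (train, val, test) counts that always sum to n."""
    num_train = int(n * train_ratio)
    num_val = int(n * val_ratio)
    # Whatever is left after truncation goes to the test split.
    num_test = n - num_train - num_val
    return num_train, num_val, num_test


if __name__ == "__main__":
    for n in (10, 19, 1000):
        print(n, split_sizes(n))
```

For 19 samples, truncating all three counts would give 15 + 1 + 1 = 17 and silently drop two lines; with the remainder rule the counts are 15, 1, and 3.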
Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: wjp/output/rec_ppocr_v3_distillation
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: wjp/output/rec/predicts_ppocrv3_distillation.txt
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs : [700]
    values : [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05
Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True
Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
  ignore_space: False
Train:
  dataset:
    name: SimpleDataSet
    data_dir: wjp/split_rec_label/train
    ext_op_transform_idx: 1
    label_file_list:
    - wjp/split_rec_label/train.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 32
    drop_last: true
    num_workers: 4
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: wjp/split_rec_label/val
    label_file_list:
    - wjp/split_rec_label/val.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
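To make the Optimizer.lr settings in the config above more concrete, here is a rough sketch of the learning rate they describe: a warmup over the first 5 epochs, 0.0005 until epoch 700, then 0.00005. This is only an illustration. The function lr_at_epoch is hypothetical, it works per epoch while the real scheduler updates per iteration, and the linear ramp starting from 0 is an assumption about the warmup shape.

```python
def lr_at_epoch(epoch, values=(0.0005, 0.00005), decay_epochs=(700,), warmup_epoch=5):
    """Approximate learning rate at a given epoch for Piecewise decay with warmup."""
    if epoch < warmup_epoch:
        # Assumed linear warmup from 0 up to the first value.
        return values[0] * epoch / warmup_epoch
    # Count how many decay boundaries have been reached; that picks the value.
    idx = sum(1 for boundary in decay_epochs if epoch >= boundary)
    return values[idx]


if __name__ == "__main__":
    for epoch in (0, 2, 5, 699, 700, 799):
        print(epoch, lr_at_epoch(epoch))
```

With epoch_num set to 800, this means the last 100 epochs run at the lower rate.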
Reference blog: https://blog.csdn.net/wyf2017/article/details/118676765
conda create -n ppocr
source activate ppocr
python tools/train.py -c wjp/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model=wjp/ch_PP-OCRv3_rec_train/best_accuracy
// -c takes the path of the config file; -o overrides config options, here setting the path of the pretrained model
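To illustrate what an override such as Global.pretrained_model=... does conceptually, the sketch below applies a dotted key=value option to a nested dictionary like the one loaded from the yml file. It is not PaddleOCR's own implementation; apply_override is a hypothetical helper, and values are kept as plain strings here.

```python
def apply_override(config, option):
    """Apply one 'Section.key=value' option to a nested config dict."""
    key_path, value = option.split("=", 1)
    keys = key_path.split(".")
    node = config
    # Walk down to the parent of the final key, creating levels if missing.
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
    return config


if __name__ == "__main__":
    config = {"Global": {"pretrained_model": None, "epoch_num": 800}}
    apply_override(config, "Global.pretrained_model=wjp/ch_PP-OCRv3_rec_train/best_accuracy")
    print(config)
```

Only the named key changes; every other option keeps the value from the yml file.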
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple <package-name>
ImportError: libGL.so.1: cannot open shared object file: No such file or directory
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python-headless
ModuleNotFoundError: No module named 'Polygon'
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple Polygon3
ModuleNotFoundError: No module named 'lanms'
Source code: https://github.com/AndranikSargsyan/lanms-nova/tree/master
To compile it, follow my tutorial: http://t.csdnimg.cn/BqOW6
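import numpy as np


def merge_quadrangle_n9(polys, thres=0.3, precision=10000):
    # An empty input has nothing to merge.
    if len(polys) == 0:
        return np.array([], dtype='float32')
    # Scale the 8 coordinate columns up before merging, then back down afterwards.
    p = polys.copy()
    p[:, :8] *= precision
    # NOTE: the inner call must reach the compiled NMS implementation, not this
    # wrapper; if the name resolves to this function itself, it recurses forever.
    ret = np.array(merge_quadrangle_n9(p, thres), dtype='float32')
    ret[:, :8] /= precision
    return ret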
pip show numpy
The Location field in the output shows where packages are installed in the current environment.
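The same information can be read from Python directly, which is handy when several environments are installed. The sketch below uses only the standard library; package_dir is a hypothetical helper name, and it is demonstrated on the standard json package so that it runs anywhere. Pass "numpy" in the ppocr environment to get the folder next to which other packages are installed.

```python
import importlib.util
import os


def package_dir(name):
    """Return the folder an importable package lives in, or None if not found."""
    spec = importlib.util.find_spec(name)
    # Namespace packages have no single origin file, so they are reported as None.
    if spec is None or spec.origin is None:
        return None
    return os.path.dirname(spec.origin)


if __name__ == "__main__":
    print(package_dir("json"))
    print(package_dir("a_package_that_does_not_exist"))
```

The parent of the returned folder is normally the site-packages directory of the active environment.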
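f = open('txt01.txt', encoding='utf-8')
Change encoding='utf-8' to GB2312, gbk, or ISO-8859-1, trying them one at a time until the file can be read. Note that ISO-8859-1 never raises a decoding error, so with Chinese text it can succeed while producing garbled characters; try it last.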
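Instead of editing the encoding by hand, the trial-and-error step can be automated. The sketch below tries a list of encodings in order and returns the first one that decodes the whole file. The helper name read_text_with_fallback and the default order are assumptions; ISO-8859-1 is left out on purpose because it accepts any byte sequence and would hide a wrong guess.

```python
import os
import tempfile


def read_text_with_fallback(path, encodings=("utf-8", "gbk")):
    """Return (text, encoding) using the first encoding that decodes the file."""
    for enc in encodings:
        try:
            with open(path, encoding=enc) as f:
                return f.read(), enc
        except UnicodeDecodeError:
            # This encoding does not fit the bytes; try the next one.
            continue
    raise ValueError(f"none of {encodings} could decode {path}")


if __name__ == "__main__":
    # Write a small GBK-encoded file to show the fallback in action.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "txt01.txt")
        with open(path, "wb") as f:
            f.write("中文标签".encode("gbk"))
        text, used = read_text_with_fallback(path)
        print(used, text)
```

If the label file turns out not to be UTF-8, it may be simpler to convert it once and keep utf-8 everywhere afterwards.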
Keep halving the batch_size_per_card parameter in the training yml config until this error no longer appears.
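The halving procedure can be sketched as a short list of values to try in order, starting from the 32 used in the Train loader of the config above. The generator name batch_size_candidates is hypothetical; it only lists the candidates and does not run training.

```python
def batch_size_candidates(start=32, minimum=1):
    """Yield start, start/2, start/4, ... down to minimum (at least 1)."""
    size = start
    # max(minimum, 1) guards against looping forever once size reaches 0.
    while size >= max(minimum, 1):
        yield size
        size //= 2


if __name__ == "__main__":
    for size in batch_size_candidates(32):
        print(f"batch_size_per_card: {size}")
```

Try each value from the top and stop at the first one that trains without the error; a larger batch size generally makes better use of the GPU.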