• 阿里天池街景字符编码YOLO5方案


    前言

    最近在做OCR相关的任务,用到了阿里天池一个街景字符识别比赛的数据集,索性就分享一下相关方案,我采用YOLO5模型,最终在平台提交分数也做到了0.924,没有经过任何优化,可以看出YOLO5的效果还是非常不错的。

    比赛地址链接:https://tianchi.aliyun.com/competition/entrance/531795/introduction?spm=5176.12281973.1005.7.3dd52448VtZc6t

    下载YOLO5模型

    YOLO5下载:https://github.com/ultralytics/yolov5
    在这里插入图片描述
    下载压缩包,然后放到自己文件夹进行解压。

    在yolo5-master中打开命令行,键入以下命令安装相关包:

    pip install -r requirements.txt
    
    • 1

    注意:安装包的时候可能会有各种各样的报错,特别是安装pycocotools的时候,不用慌,把报错复制粘贴到百度上面,都能解决!

    获取数据集

    YOLO已经准备好了,现在把比赛数据集拿出来,解析数据我就不自己写了,直接采用另外一位论坛上面老哥的代码,这里是训练集的处理,验证集也是一样的:

    import os
    import cv2
    import json
    train_json = json.load(open('mchar_train.json'))
    for x in train_json:
        img=cv2.imread("images/train/"+x)
        width=img.shape[1]
        height=img.shape[0]
        train_label =list(map(int,train_json[x]['label']))
        train_height=list(map(int,train_json[x]['height']))
        train_left=list(map(int,train_json[x]['left']))
        train_width=list(map(int,train_json[x]['width']))
        train_top=list(map(int,train_json[x]['top']))
        loc_pic="labels/train/"+x.split('.')[0]+'.txt' 
        pic=open(loc_pic,"w")
        for i in range(len(train_label)):
            pic_label=train_label[i]
            pic_x=(train_left[i]+train_width[i]/2)/width
            pic_y=(train_top[i]+train_height[i]/2)/height
            pic_width=train_width[i]/width
            pic_height=train_height[i]/height            
            pic.write(str(pic_label)+" "+str(pic_x)+" "+str(pic_y)+" "+str(pic_width)+" "+str(pic_height))
            pic.write("\n")
        pic.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    解析后的数据都是txt格式的,因为YOLO模型输入格式要求也是这样
    在这里插入图片描述
    数据处理好了以后,我们在yolo5-master中创建一个名为tianchi的文件夹,文件夹结构如下:
    在这里插入图片描述
    文件夹创建好后,把对应的数据拷贝进相应文件夹中就行了。

    模型训练

    我们把models文件夹中yolo5s.yaml文件复制一份到tianchi文件夹中,同时把data文件夹中coco128.yaml文件也复制一份到tianchi文件夹中,并且把yolo5s.yaml改名为street_yolo5s.yaml,把coco128.yaml改名为street_yolo.yaml,改好后如下图:
    在这里插入图片描述
    然后我们再将这两个文件中的内容进行修改,首先修改street_yolo5s.yaml:将nc的值改为10
    在这里插入图片描述
    然后修改street_yolo.yaml文件,只需要修改train和val的路径,还有nc和names就行了,然后把path那一行注释掉,修改后如下:
    在这里插入图片描述
    改完之后我们就可以进行模型训练了!!!
    在yolo5-master中打开命令行,执行以下命令(这里我只设置了20个epoch作为示例,我自己是100个epoch训练后才是0.924):

    python train.py --data tianchi/streat_yolo.yaml --cfg tianchi/street_yolo5s.yaml --epochs 100
    
    • 1

    在这里插入图片描述
    训练会花费很多时间,我训练了21个小时,电脑太垃圾了(GTX1050)!

    测试数据预测

    将test图片数据放入images文件夹中:
    在这里插入图片描述
    然后执行如下命令即可:

    python detect.py --weights runs/train/exp/weights/best.pt --source  tianchi/images/test/ --save-txt
    
    • 1

    预测完成后在runs/detect/exp中可以可以看到训练后的结果:
    在这里插入图片描述

    提交结果

    预测出来的labels格式不是最终提交结果,我们要按照比赛要求的提交结果来,所以还要对结果进行一点处理:

    import pandas as pd
    import glob
    import os
    def get(elem):
        return elem[1]
    label_path=glob.glob('labels/*.txt')
    label_path.sort()
    df_submit = pd.read_csv('mchar_sample_submit_A.csv')
    df_submit.set_index('file_name')
    for x in label_path:
        text=open(x,'r')
        result_list=[]
        for line in text.readlines():
            result_list.append((line.split(' ')[0],line.split(' ')[1]))
        result_list.sort(key=get)
        result=''
        for j in result_list:
            result+=j[0]
        label_path=x.split('\\')[-1].split('.')[0]+'.png'
        df_submit['file_code'][df_submit['file_name']==label_path]=result
        text.close()
    df_submit.to_csv('content/submit.csv', index=None)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22

    将submit.csv提交到天池平台上面:
    在这里插入图片描述

    写在最后

    上面的成绩只是单纯的调用了模型,没有进行任何调优和融合,可以看出YOLO5的效果还是很好的,我们也可以采用不同的YOLO权重参数进行训练尝试,效果会更好。同时也感谢论坛各位大佬提供的各种想法和代码,我也是学到了很多,本人才疏学浅,如果有不对的地方希望指正!

  • 相关阅读:
    Alter database open fails with ORA-00600 kcratr_nab_less_than_odr
    io.lettuce.core.RedisCommandTimeoutException: Command timed out 解决办法
    【Verilog基础】8.加法器
    js给一段话,遇到的第一个括号处加上换行符
    【目标检测】数据增强:YOLO官方数据增强实现/imgaug的简单使用
    实现 企业微信认证 网络准入认证 配置
    [BJDCTF2020]EasySearch Apache SSI漏洞
    lambda表达式在实际开发中的使用
    从Jenkins中获取Maven项目版本号
    LLVM学习入门(1):Kaleidoscope语言及词法分析
  • 原文地址:https://blog.csdn.net/qq_44694861/article/details/124523492