• CenterNet根据自己的数据训练模型


    本文参考:

    1、数据集相关的:https://blog.csdn.net/weixin_42634342/article/details/97697356

    2、训练自己的模型参考:

    https://bbs.huaweicloud.com/blogs/210374

    https://www.huaweicloud.com/articles/ebb05fa50237d7ac7ad6c0b29e38f969.html

    https://blog.csdn.net/jiangpeng59/article/details/105732166

    3、CenterNet官方安装文档:https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

    一、数据集处理

    1、训练集下载

    SeaShips数据集,链接为:http://www.lmars.whu.edu.cn/prof_web/shaozhenfeng/datasets/SeaShips(7000).zip,

    如果在linux上直接wget获取即可

    该版本数据集共有7000张图片,图片分辨率均为1920*1080,分为六类船只,主要是一些内河道中船只的图片。

    该数据集为PascalVOC数据集。

    2、voc数据集格式转coco格式

    centernet虽然同时支持coco和voc数据集,但是本次我们需要转成coco格式的数据集,方便后续进行操作

    转换的参考代码为:https://blog.csdn.net/yang332233/article/details/97205112

    只需要修改最后几行加粗的位置相关的代码即可。

    转换完毕之后的json文件后续会放到centernet对应的目录下面

    二、CenterNet代码编译

    1、安装python3.6环境

    2、安装pytorch 1.x以及pytorchvision 0.x版本

    3、安装cocoapi

    # COCOAPI=/path/to/clone/cocoapi

    git clone https://github.com/cocodataset/cocoapi.git $COCOAPI

    cd $COCOAPI/PythonAPI

    make

    python setup.py install --user

    4、下载CenterNet

    CenterNet_ROOT=/path/to/clone/CenterNet

    git clone https://github.com/xingyizhou/CenterNet $CenterNet_ROOT

    5、安装CenterNet依赖的python包

    pip install -r requirements.txt

    6、编译DCNv2

    CenterNet自带的DCNv2只支持pytorch0.4,随意会导致后续编译不成功,所以需要删除$CenterNet_ROOT/src/lib/models/networks/DCNv2目录,然后重新下载最新的DCNv2代码进行编译。

    对应的git地址是:git clone https://github.com/CharlesShang/DCNv2.git

    下载完成后进行编译: 

    ./make.sh

    7、编译NMS组件

    cd $CenterNet_ROOT/src/lib/external

    make

    三、使用已有的线上模型进行预测

    1、下载训练好的模型

    模型下载地址见:https://github.com/xingyizhou/CenterNet/blob/master/readme/MODEL_ZOO.md,

    如果做目标检测,可以下载下图所指的模型,这个模型大概是77M。

    模型下载后放到models目录下即可。

     

    2、修改结果输出方式

    centernet默认是把预测结果图片输出到screen,但是docker中无法显示导致报错,所以需要修改输出方式。

    修改 src/lib/detectors/cdet.py。

    将debugger.show_all_imgs(pause=self.pause) 注释掉,

    换成debugger.save_all_imgs(path='/home/jhsu/sujh/ljj/CenterNet/output', genID=True)

    如下图所示:

     

    3、图片进行目标检测

    随便网上找一张图片,执行以下命令:

    python src/demo.py ctdet --demo images/17790319373_bd19b24cfc_k.jpg --load_model models/ctdet_coco_dla_2x.pth

    运行出错可参考:http://blog.sina.com.cn/s/blog_628cc2b70102ysyi.html,相关问题可以参考进行解决

    执行完毕后会在output目录下输出结果,一般是xctdet.png的文件。得到的结果为:

     

    四、使用已有的数据集进行训练

    1、存放数据集

    在CenterNet主目录的data目录下创建MyDataTest,如下图所示:


    annotations目录中存放第一步的json文件,比如我只有train.json文件

    images目录存放7000张jpg文件

    2、在CenterNet-master/src/lib/datasets/dataset/文件夹里面,复制coco.py并从命名为my_test.py

    打开my_test.py修改:

    line13:class COCO修改成class my_test

    line14:num_classes = 6 #注意这里不包含背景类,只有6种船的类型

    line15:default_resolution = [512, 512] 修改自己需要的训练图片大小,虽然我们的图片是1920*1080,但是无需修改

    line16,18:均值方差改自己的,或者也可以不改

    Line22:super(COCO, self).init()里面的COCO换成自己的类名my_test

    Line23,24:修改自己的数据路径

     
    line26-37:修改自己json文件名:

    line39:类别名字和类别id改成自己的

    3、dataset_factory.py修改

    将数据集加入CenterNet-master/src/lib/datasets/dataset_factory.py

    Line14 添加:from .dataset.my_test import my_test

    Line29添加: ‘my_test':my_test

     

    4、/src/lib/opts.py修改

    加入自己数据集

     

    line336: 修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

     

    5、CenterNet-master/src/lib/utils/debugger.py修改

    Line 45添加:

    Line 458添加:

     

    6、训练数据

    参数说明:

    (1)arch:代表选择的backbone的类型

    (2)img_size:控制图片长和宽

    (3)lr和lr_step:控制学习率大小及变化

    (4)batch_size:一个批次处理的图片个数

    (5)num_epochs:学习数据集的总次数

    (6)num_works:开启多少个线程加载数据集

    在src目录下使用命令:

    python main.py ctdet --dataset my_test --exp_id my_test --batch_size 4 --lr 0.001 --gpus 1 --num_workers 4

    或者:nohup python main.py ctdet --dataset my_test --exp_id my_test --batch_size 4 --lr 0.001 --gpus 1 --num_workers 4 > nohup3.out 2>&1 &

    默认是迭代140轮完成训练,如果嫌时间太久了,可以修改opts.py文件如下,指定运行5次就可以了,或者命令中带--num_epochs 5也是可以的。

     

    运行之后的日志如下图所示:





    查看linux进程,会发现有5个正在跑的进程,因为指定了4个work同时进行训练,这4个work是多线程进行图片加载,另外1个是在训练模型。

    模型运行完毕之后,会在exp/ctet/my_test下生成两个pth模型文件。

    7、生成训练的loss曲线图

    按照上一步进行训练后,会在exp/ctdet/my_test/logs_xxx的目录下生成log.txt文件。

    里面的数据只记录每一轮迭代完之后的loss信息,progressbar每一张的loss数据是不会写进log.txt文件的。

    具体的信息如下图所示:

     

    然后就是读取上面的信息,生成loss的曲线图,参考代码如下:

    1. import matplotlib.pyplot as plt
    2. def plot_loss_curve(log_file):
    3. loss_data = open(log_file)
    4. all_lines = loss_data.readlines()
    5. print(all_lines[4].split(' '))
    6. total_loss = []
    7. hm_loss = []
    8. wh_loss = []
    9. off_loss = []
    10. val_loss = []
    11. spend_time = []
    12. num_lines = len(all_lines)
    13. for line in range(num_lines):
    14. total_loss1 = all_lines[line].split(' ')[4]
    15. hm_loss1 = all_lines[line].split(' ')[7]
    16. wh_loss1 = all_lines[line].split(' ')[10]
    17. off_loss1 = all_lines[line].split(' ')[13]
    18. spend_time1 = all_lines[line].split(' ')[16]
    19. print(total_loss1)
    20. print(spend_time1)
    21. total_loss.append(float(total_loss1))
    22. hm_loss.append(float(hm_loss1))
    23. wh_loss.append(float(wh_loss1))
    24. off_loss.append(float(off_loss1))
    25. spend_time.append(float(spend_time1))
    26. return total_loss
    27. if __name__ == '__main__':
    28. loss_res18 = plot_loss_curve("D:\\temp\\centernet_log.txt")
    29. fig = plt.figure(figsize=(104))
    30. ax = fig.add_subplot(111)
    31. ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)
    32. ax.set_xlim([16])
    33. ax.set_xticks(range(051))
    34. ax.set_yticklabels(['jan''feb''mar'])
    35. ax.set_xlabel('epochs')
    36. ax.set_ylabel('loss_value')
    37. ax.text(875020"plane", color='red')
    38. ax.set_title('loss_of_CenterNet')
    39. ax.legend(loc='best')
    40. ax.grid()
    41. plt.show()

    生成的曲线如下图所示:

     

    8、图片目标检测操作

    从网上找一张船的图片放到images目录下,然后运行如下命令:

    python src/demo.py ctdet --demo images/001987.jpg --load_model /workspace/hugh/CenterNet-master/exp/ctdet/my_test/model_best.pth --vis_thresh 0.1

    上一步模型训练只迭代了5轮,模型准确度是不高的,如果不设置vis_thresh会导致图片中检测不到目标。

     

  • 相关阅读:
    七种交换变量值的方法,看看你知道几种
    NIO的浅了解
    如何本地部署开源AI知识库 FastGPT(新手教程)
    计算性能的提升之异步计算与并行计算(MXNet)
    SpringBoot使用log4j2将日志记录到文件及自定义数据库
    匿名页的反向映射
    吊死人小游戏 2.0版本
    Python之三大基本库——Numpy(1)
    商业化广告--体系学习-- 16 -- 业务实战篇 --需求调研:广告产品潜在需求的调研流程是怎样的?
    MySQL客户端工具的使用与MySQL SQL语句
  • 原文地址:https://blog.csdn.net/benben044/article/details/126525532