本文参考:
1、数据集相关的:https://blog.csdn.net/weixin_42634342/article/details/97697356
2、训练自己的模型参考:
https://bbs.huaweicloud.com/blogs/210374
https://www.huaweicloud.com/articles/ebb05fa50237d7ac7ad6c0b29e38f969.html
https://blog.csdn.net/jiangpeng59/article/details/105732166
3、CenterNet官方安装文档:https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md
一、数据集处理
1、训练集下载
SeaShips数据集,链接为:http://www.lmars.whu.edu.cn/prof_web/shaozhenfeng/datasets/SeaShips(7000).zip,
如果在linux上直接wget获取即可
该版本数据集共有7000张图片,图片分辨率均为1920*1080,分为六类船只,主要是一些内河道中船只的图片。
该数据集为PascalVOC数据集。
2、voc数据集格式转coco格式
centernet虽然同时支持coco和voc数据集,但是本次我们需要转成coco格式的数据集,方便后续进行操作
转换的参考代码为:https://blog.csdn.net/yang332233/article/details/97205112
只需要修改最后几行加粗的位置相关的代码即可。
转换完毕之后的json文件后续会放到centernet对应的目录下面
二、CenterNet代码编译
1、安装python3.6环境
2、安装pytorch 1.x以及pytorchvision 0.x版本
3、安装cocoapi
# COCOAPI=/path/to/clone/cocoapi
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI
cd $COCOAPI/PythonAPI
make
python setup.py install --user
4、下载CenterNet
CenterNet_ROOT=/path/to/clone/CenterNet
git clone https://github.com/xingyizhou/CenterNet $CenterNet_ROOT
5、安装CenterNet依赖的python包
pip install -r requirements.txt
6、编译DCNv2
CenterNet自带的DCNv2只支持pytorch0.4,随意会导致后续编译不成功,所以需要删除$CenterNet_ROOT/src/lib/models/networks/DCNv2目录,然后重新下载最新的DCNv2代码进行编译。
对应的git地址是:git clone https://github.com/CharlesShang/DCNv2.git
下载完成后进行编译:
./make.sh
7、编译NMS组件
cd $CenterNet_ROOT/src/lib/external
make
三、使用已有的线上模型进行预测
1、下载训练好的模型
模型下载地址见:https://github.com/xingyizhou/CenterNet/blob/master/readme/MODEL_ZOO.md,
如果做目标检测,可以下载下图所指的模型,这个模型大概是77M。
模型下载后放到models目录下即可。
2、修改结果输出方式
centernet默认是把预测结果图片输出到screen,但是docker中无法显示导致报错,所以需要修改输出方式。
修改 src/lib/detectors/cdet.py。
将debugger.show_all_imgs(pause=self.pause) 注释掉,
换成debugger.save_all_imgs(path='/home/jhsu/sujh/ljj/CenterNet/output', genID=True)
如下图所示:
3、图片进行目标检测
随便网上找一张图片,执行以下命令:
python src/demo.py ctdet --demo images/17790319373_bd19b24cfc_k.jpg --load_model models/ctdet_coco_dla_2x.pth
运行出错可参考:http://blog.sina.com.cn/s/blog_628cc2b70102ysyi.html,相关问题可以参考进行解决
执行完毕后会在output目录下输出结果,一般是xctdet.png的文件。得到的结果为:
四、使用已有的数据集进行训练
1、存放数据集
在CenterNet主目录的data目录下创建MyDataTest,如下图所示:
annotations目录中存放第一步的json文件,比如我只有train.json文件
images目录存放7000张jpg文件
2、在CenterNet-master/src/lib/datasets/dataset/文件夹里面,复制coco.py并从命名为my_test.py
打开my_test.py修改:
line13:class COCO修改成class my_test
line14:num_classes = 6 #注意这里不包含背景类,只有6种船的类型
line15:default_resolution = [512, 512] 修改自己需要的训练图片大小,虽然我们的图片是1920*1080,但是无需修改
line16,18:均值方差改自己的,或者也可以不改
Line22:super(COCO, self).init()里面的COCO换成自己的类名my_test
Line23,24:修改自己的数据路径
line26-37:修改自己json文件名:
line39:类别名字和类别id改成自己的
3、dataset_factory.py修改
将数据集加入CenterNet-master/src/lib/datasets/dataset_factory.py
Line14 添加:from .dataset.my_test import my_test
Line29添加: ‘my_test':my_test
4、/src/lib/opts.py修改
加入自己数据集
line336: 修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):
5、CenterNet-master/src/lib/utils/debugger.py修改
Line 45添加:
Line 458添加:
6、训练数据
参数说明:
(1)arch:代表选择的backbone的类型
(2)img_size:控制图片长和宽
(3)lr和lr_step:控制学习率大小及变化
(4)batch_size:一个批次处理的图片个数
(5)num_epochs:学习数据集的总次数
(6)num_works:开启多少个线程加载数据集
在src目录下使用命令:
python main.py ctdet --dataset my_test --exp_id my_test --batch_size 4 --lr 0.001 --gpus 1 --num_workers 4
或者:nohup python main.py ctdet --dataset my_test --exp_id my_test --batch_size 4 --lr 0.001 --gpus 1 --num_workers 4 > nohup3.out 2>&1 &
默认是迭代140轮完成训练,如果嫌时间太久了,可以修改opts.py文件如下,指定运行5次就可以了,或者命令中带--num_epochs 5也是可以的。
运行之后的日志如下图所示:
查看linux进程,会发现有5个正在跑的进程,因为指定了4个work同时进行训练,这4个work是多线程进行图片加载,另外1个是在训练模型。
模型运行完毕之后,会在exp/ctet/my_test下生成两个pth模型文件。
7、生成训练的loss曲线图
按照上一步进行训练后,会在exp/ctdet/my_test/logs_xxx的目录下生成log.txt文件。
里面的数据只记录每一轮迭代完之后的loss信息,progressbar每一张的loss数据是不会写进log.txt文件的。
具体的信息如下图所示:
然后就是读取上面的信息,生成loss的曲线图,参考代码如下:
- import matplotlib.pyplot as plt
-
- def plot_loss_curve(log_file):
- loss_data = open(log_file)
- all_lines = loss_data.readlines()
- print(all_lines[4].split(' '))
- total_loss = []
- hm_loss = []
- wh_loss = []
- off_loss = []
- val_loss = []
- spend_time = []
- num_lines = len(all_lines)
-
- for line in range(num_lines):
- total_loss1 = all_lines[line].split(' ')[4]
- hm_loss1 = all_lines[line].split(' ')[7]
- wh_loss1 = all_lines[line].split(' ')[10]
- off_loss1 = all_lines[line].split(' ')[13]
- spend_time1 = all_lines[line].split(' ')[16]
-
- print(total_loss1)
- print(spend_time1)
-
- total_loss.append(float(total_loss1))
- hm_loss.append(float(hm_loss1))
- wh_loss.append(float(wh_loss1))
- off_loss.append(float(off_loss1))
- spend_time.append(float(spend_time1))
-
- return total_loss
-
- if __name__ == '__main__':
- loss_res18 = plot_loss_curve("D:\\temp\\centernet_log.txt")
- fig = plt.figure(figsize=(10, 4))
- ax = fig.add_subplot(111)
- ax.plot(range(len(loss_res18)), loss_res18, 'c', label='building', linewidth=1)
- ax.set_xlim([1, 6])
- ax.set_xticks(range(0, 5, 1))
- ax.set_yticklabels(['jan', 'feb', 'mar'])
- ax.set_xlabel('epochs')
- ax.set_ylabel('loss_value')
- ax.text(8750, 20, "plane", color='red')
- ax.set_title('loss_of_CenterNet')
- ax.legend(loc='best')
- ax.grid()
- plt.show()
生成的曲线如下图所示:
8、图片目标检测操作
从网上找一张船的图片放到images目录下,然后运行如下命令:
python src/demo.py ctdet --demo images/001987.jpg --load_model /workspace/hugh/CenterNet-master/exp/ctdet/my_test/model_best.pth --vis_thresh 0.1
上一步模型训练只迭代了5轮,模型准确度是不高的,如果不设置vis_thresh会导致图片中检测不到目标。