• DAMO-YOLO训练KITTI数据集


    1.KITTI数据集准备

    DAMO-YOLO支持COCO格式的数据集,在训练KITTI之前,需要将KITTI的标注转换为KITTI格式。KITTI是采取逐个文件标注的方式确定的,即一张图片对应一个label文件。下面是KITTI 3D目标检测训练集的第一个标注文件:000000.txt

    Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01

    就不一一解释了,引用一下KITTI 3D目标检测数据集解析(完整版)_kitti数据集结构-CSDN博客的表格,可以看一下该文章的详细解释

    在这里插入图片描述

    COCO格式就不再详细解释了,可以看一下这篇文章COCO数据集(目标检测任务json文件内容总结,总结一下COCO格式需要一个大的json文件,里面包含了每个图片的路径及注释,bbox由中心坐标和宽高的形式给出。

    因此,KITTI格式转COCO,就是需要读取逐个的txt文件,进行坐标换算后写入json,代码如下,基于TXT 转成COCO jason格式的标注_D_galaxy的博客-CSDN博客修改:

    import cv2
    from math import *
    import numpy as np
    import os, random, shutil
    import glob as gb
    from time import sleep
    import copy
    import json
    
    
    def copyFile2Folder(srcfile, dstfolder):
        '''
        复制文件到指定文件夹,名字和以前相同
        Args:
            srcfile: '/home/wsd/***/yolov5/data/PCB_DATASET/labels/Spur/04_spur_06.txt'  文件的绝对路径
            dstfile: '/home/wsd/***/yolov5/data/PCB_DATASET/train/labels'  文件夹
    
        Returns:
    
        '''
    
        if not os.path.isfile(srcfile):
            print("%s not exist!" % (srcfile))
    
        else:
            src_fpath, src_fname = os.path.split(srcfile)  # 分离文件名和路径
            if not os.path.exists(dstfolder):
                os.makedirs(dstfolder)  # 创建路径dst_file
    
            dst_file = os.path.join(dstfolder, src_fname)
            shutil.copyfile(srcfile, dst_file)  # 复制文件
            print("copy %s -> %s" % (srcfile, dst_file))
            return dst_file
    
    
    class cocoJsaon(object):
        '''
       coco 的json 的文件格式类
       '''
    
        def __init__(self, categories):
            self.info = {'description': 'PCB DATASET',
                         'url': 'DLH',
                         'version': '1.0',
                         'year': 2021,
                         'contributor': 'DLHgroup',
                         'date_created': '2021-01-12 16:11:52.357475'
                         }
            self.license = {
                "url": "none",
                "id": 1,
                "name": "Attribution-NonCommercial-ShareAlike License"}
    
            self.images = None
            self.annotations = None
            self.category = categories
            self.cocoJasonDict = {"info": self.info, "images": self.images, "annotations": self.annotations,
                                  "licenses": self.license, 'categories': self.category}
    
        def getDict(self):
            '''
    
          Returns: 返回 格式的字典化
    
          '''
    
            self.cocoJasonDict = {"info": self.info, "images": self.images, "annotations": self.annotations,
                                  "licenses": self.license, 'categories': self.category}
            return self.cocoJasonDict
    
    
    if __name__ == '__main__':
    
        # 文件原本:
        '''
        root: /home/dlh/opt/***/PCB_DATASET
                        ------------------->labels/  # 原本的所有目标检测框的   *.txt
                        ------------------->images/   #  *.jpg  所有的图片
                        ------------------->ImageSets/  # train.txt  和  val.txt
                        ------------------->annotations  /  存放 labels 下所有对应的train.json
    
        最终:
        root: /home/dlh/opt/***/PCB_DATASET/PCB   
                            ------------------->images/   #  *.jpg  所有的图片
                            ------------------->annotations  /  instances_train_set_name.json   # 存放 labels 下所有对应的train.json
                                                             /  instances_val_set_name.json     # 存放 labels val.json
    
    
        '''
    
        # 写入的train 还是Val 
        wrtie_str = 'train'
        train_path = '/home/wistful/Datasets/KITTI/training/'
        # 存放 train.txt  和  val.txt  的绝对地址    (修改)
        Imageset = '/home/wistful/Datasets/KITTI/ImageSets/' + wrtie_str + '.txt'
        # 存放 即将所有的原本图片  保存到 该地址  临时       (修改)
        tarset = '/home/wistful/Datasets/KITTI/training/image_2/' + wrtie_str + '_set_name'
        # 下面是更改 json 文件 的
        tempDir = Imageset.replace('txt', 'json')
        tempDir = tempDir.replace('ImageSets', 'annotations')
        jsonFile = tempDir.replace(wrtie_str, 'instances_' + wrtie_str + '_set_name')
        jasonDir, _ = os.path.split(jsonFile)
        # 告诉你 最新的Jason 文件保存到了那里
        print(f'jsonFile saved {jsonFile}')
    
        # 检查目标文件夹是否存在
        if not os.path.exists(tarset):
            os.makedirs(tarset)
        if not os.path.exists(jasonDir):
            os.makedirs(jasonDir)
    
        # images 段 的字典模板
        images = {"license": 3,
                  "file_name": "COCO_val2014_000000391895.jpg",
                  "coco_url": "",
                  "height": 360, "width": 640, "date_captured": "2013-11-14 11:18:45",
                  "id": 0}
    
        # annotation 段 的字典模板
        an = {"segmentation": [],
              "iscrowd": 0,
              "keypoints": 0,
              "area": 10.0,
              "image_id": 0, "bbox": [], "category_id": 0,
              "id": 0}
    
        # categories 段 的字典模板
        cate_ = {
            'id': 0,
            'name': 'a',
        }
    
        # 用来保存目标类的  字典
        cate_list = []
        # 你的目标类有几个  (修改)
        className = ['Pedestrian', 'Car', 'Cyclist', 'DontCare']
        carName = ['Car', 'Van', 'Truck', 'Tram']  # 车的类型,最终会归一成car
        personName = ['Pedestrian', 'Person_sitting']  # 人员类型,最终归一为Pedestrian
        dontCare = ['Misc', 'DontCare']  # 杂物,归一为DontCare
    
        temId = 0
        for idName in className:
            tempCate = cate_.copy()
            tempCate['id'] = temId
            temId += 1
            tempCate['name'] = idName
    
            cate_list.append(tempCate)
    
        # print(cate_list)
        # 创建coco json 的类 实例
        js = cocoJsaon(cate_list)
    
        image_lsit = []
        annoation_list = []
    
        # 打开 train。txt
        with open(Imageset, 'r') as f:
            lines = f.readlines()
            # print(f'imageset lines:{lines}')
    
        img_id = 0
        bbox_id = 0
        # 按值去打开图片
        for path in lines:
            # 我的train.txt 是按照绝对路径保存的,各位需要的根据自己的实际情况来修改这里的代码
            # 去出  \n 之类的空格
            path = path.lstrip().rstrip()
            # 得到图像文件路径
            img_path = train_path + 'image_2/' + path + '.png'
            # print(f'path:{path}')
            # 打开图片
            image = cv2.imread(img_path)
            # 将这个图片副知道新的文件夹  (以实际情况  修改)
            copyFile2Folder(img_path, tarset)
            # 得到宽高
            (height, width) = image.shape[:2]
            # (height, width) = 375, 1242
            # 得到文件名子
            _, fname = os.path.split(img_path)
            # print(f'_ and name:{_, fname}')
            # 图像对应的txt 文件路径
            txtPath = train_path + 'label_2/' + path + '.txt'
    
            # print(f'txtPath:{txtPath}')
            # 复制images 的字典的复制
            image_temp = images.copy()
            image_temp['file_name'] = fname
            image_temp['height'] = height
            image_temp['width'] = width
            image_temp['id'] = img_id
            # 将其放入到集合中
            image_lsit.append(image_temp)
            # 打开图片的对应的txt 目标文件的txt
            # print(f'txt path:{txtPath}')
            with open(txtPath, 'r') as re:
                txtlines = re.readlines()
                for txtline in txtlines:
                    # 去出  \n 之类的空格
                    temp = txtline.rstrip().lstrip().split(' ')
                    # print(f'temp:{temp}')
                    # 分别的到 目标的类 中心值 xy  和  该检测框的宽高
                    if temp[0] in carName:
                        classid = className.index('Car')
                    elif temp[0] in personName:
                        classid = className.index('Pedestrian')
                    elif temp[0] in dontCare:
                        classid = className.index('DontCare')
                    else:
                        classid = className.index('Cyclist')
                    # classid = className.index(temp[0])  # 获取id
                    # 计算宽高及中心
                    w = float(temp[6]) - float(temp[4])
                    h = float(temp[7]) - float(temp[5])
                    x = (float(temp[6]) + float(temp[4]))/2
                    y = (float(temp[7]) + float(temp[5]))/2
                   
                    iscrowd = int(temp[2])
                    # 判断是否遮挡
                    if iscrowd != 0:
                        iscrowd = 1
                    # 计算面积
                    area = w * h
                    # 复制annotation 的字典
                    temp_an['area'] = area
                    temp_an = an.copy()
                    temp_an['image_id'] = img_id
                    temp_an['bbox'] = [x, y, w, h]
                    temp_an['iscrowd'] = iscrowd
                    temp_an['category_id'] = classid
                    temp_an['id'] = bbox_id
                    bbox_id += 1  # 这个是 这个annotations 的id 因为一个图像可能对应多个 目标的id
                    annoation_list.append(temp_an)
            # 图像的id
            img_id += 1
    
        # print(js.getDict())
        # print('***********************************************************************\n\n')
        # 将json 的实例 中images  赋值
        js.images = image_lsit
        # 将json 的实例 annotations  赋值
        js.annotations = annoation_list
        # 写入文件
        json_str = json.dumps(js.getDict())
        with open(jsonFile, 'w+') as ww:
            ww.write(json_str)
    
        print('finished')
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249

    上述代码,只需更改主函数开头的wrtie_strtrain_path就行了,代码不难理解

    2.修改DAMO-YOLO的配置文件

    • 修改damo/config/paths_catalog.py,将coco_2017_traincoco_2017_val的相关路径修改一下
    # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
    # Copyright (C) Alibaba Group Holding Limited. All rights reserved.
    """Centralized catalog of paths."""
    import os
    
    
    class DatasetCatalog(object):
        DATA_DIR = 'datasets'
        DATASETS = {
            'coco_2017_train': {
                'img_dir': '/home/wistful/Datasets/KITTI/training/image_2',
                'ann_file': '/home/wistful/Datasets/KITTI/annotations/instances_train_set_name.json' # 第一步生成的json文件
            },
            'coco_2017_val': {
                'img_dir': '/home/wistful/Datasets/KITTI/training/image_2',
                'ann_file': '/home/wistful/Datasets/KITTI/annotations/instances_val_set_name.json'
            },
            'coco_2017_test_dev': {
                'img_dir': '/home/wistful/Datasets/KITTI/training/image_2',
                'ann_file': '/home/wistful/Datasets/KITTI/annotations/instances_val_set_name.json'
            },
        }
    
        @staticmethod
        def get(name):
            if 'coco' in name:
                data_dir = DatasetCatalog.DATA_DIR
                attrs = DatasetCatalog.DATASETS[name]
                args = dict(
                    root=os.path.join(data_dir, attrs['img_dir']),
                    ann_file=os.path.join(data_dir, attrs['ann_file']),
                )
                return dict(
                    factory='COCODataset',
                    args=args,
                )
            else:
                raise RuntimeError('Only support coco format now!')
            return None
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 修改配置文件中的ZeroHeadself.dataset.class_names

    在这里插入图片描述

    • 导入预训练权重及修改训练轮数,GitHub官方仓库里有,就不介绍了

    3.训练

    python -m torch.distributed.launch --nproc_per_node=2 tools/train.py -f configs/damoyolo_tinynasL20_T.py
    
    • 1

    2指的是gpu个数,-f 后面是配置文件

    在这里插入图片描述

  • 相关阅读:
    C 语言标准库
    「面经分享」小米java岗二面面经,已拿offer
    CMake变量可见性学习
    如何做好接口自动化测试?
    oracle灾备切换和回切步骤以及sql执行语句
    Dreambooth工作原理
    Linux基础概念--进程、子进程、进程组和会话
    TCP/IP、DTN网络通信协议族
    java实现幂等性校验
    开环模块化多电平换流器仿真(MMC)N=6(Simulink仿真)
  • 原文地址:https://blog.csdn.net/u014295602/article/details/133309331