• Setting up detectron2 and training on a custom COCO dataset (VOC to COCO conversion)


    detectron2 is best set up on Ubuntu; on Windows the build will most likely fail.

    1. Environment Setup

    Create and activate a conda virtual environment:

    conda create -n detectron2 python=3.8 -y
    conda activate detectron2
    

    Then fetch the source code (cloning with git is recommended) and install it in editable mode:

    git clone https://github.com/facebookresearch/detectron2.git
    python -m pip install -e detectron2
    
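    A quick sanity check that the editable install worked (a minimal sketch; the printed version will vary with your checkout):

    import torch
    import detectron2

    print("detectron2:", detectron2.__version__)
    print("CUDA available:", torch.cuda.is_available())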

    2. Building the Dataset

    Start with two folders: one containing the images and one containing the completed VOC-format XML annotations. First split the dataset by creating the following structure:

    -- ImageSets
    	---Main
    		----test.txt
    		----train.txt
    		----trainval.txt
    		----val.txt
    

    Run the following script to perform the split:

    import os
    import random
    
    
    def main():
        trainval_percent = 0.2  # fraction of all files that go to trainval (then split into test/val)
        train_percent = 0.8     # fraction of trainval that goes to test
        xmlfilepath = r'D:\lunwen\data\data\xml'    # directory of XML annotations
        txtsavepath = r'D:\lunwen\data\data\image'  # directory of images
        total_xml = os.listdir(xmlfilepath)
    
        num = len(total_xml)
        indices = range(num)
        tv = int(num * trainval_percent)
        tr = int(tv * train_percent)
        trainval = random.sample(indices, tv)
        train = random.sample(trainval, tr)
    
        # Adjust the paths below to your own ImageSets/Main directory
        ftrainval = open(r'D:\lunwen\data\data\ImageSets\Main\trainval.txt', 'w')
        ftest = open(r'D:\lunwen\data\data\ImageSets\Main\test.txt', 'w')
        ftrain = open(r'D:\lunwen\data\data\ImageSets\Main\train.txt', 'w')
        fval = open(r'D:\lunwen\data\data\ImageSets\Main\val.txt', 'w')
    
        # With the percentages above, 80% of all files end up in train.txt;
        # the remaining 20% are listed in trainval.txt and further split into
        # test.txt (80% of them) and val.txt (20% of them).
        for i in indices:
            name = total_xml[i][:-4] + '\n'
            if i in trainval:
                ftrainval.write(name)
                if i in train:
                    ftest.write(name)
                else:
                    fval.write(name)
            else:
                ftrain.write(name)
    
        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()
    
    
    if __name__ == '__main__':
        main()
    

    After the split, each txt file contains its share of filenames. Next, extract the corresponding training/validation images and their XML files.

    Create the following folders in the source directory, then run the script below to copy the matching files into them:

    -train_JPEGImages
    -train_annotations
    -val_JPEGImages
    -val_annotations
    
    import os
    import shutil
    
    
    class CopyXml():
        def __init__(self):
            # paths to your XML annotations and images
            self.xmlpath = r'D:\lunwen\data\data\xml'
            self.jpgpath = r'D:\lunwen\data\data\image'
            # destination paths for this split's xml and jpg files
            self.newxmlpath = r'D:\lunwen\data\data\val_annotations'
            self.newjpgpath = r'D:\lunwen\data\data\val_JPEGImages'
    
        def startcopy(self):
            filelist = os.listdir(self.xmlpath)  # file list in this directory
            # print(len(filelist))
            test_list = loadFileList()
            # print(len(test_list))
            for f in filelist:
                xmldir = os.path.join(self.xmlpath, f)
                (shotname, extension) = os.path.splitext(f)
                jpgdir = os.path.join(self.jpgpath, shotname + '.jpg')
                if str(shotname) in test_list:
                    # print('success')
                    shutil.copyfile(str(xmldir), os.path.join(self.newxmlpath, f))
                    shutil.copyfile(str(jpgdir), os.path.join(self.newjpgpath, shotname + '.jpg'))
    
    
    # load the list of train/test files
    def loadFileList():
        filelist = []
        # point this at the txt file produced in the previous step
        f = open(r'D:\lunwen\data\data\ImageSets\Main\val.txt', 'r')
        # f = open("VOC2007/ImageSets/Main/train.txt", "r")
        lines = f.readlines()
        for line in lines:
            # strip the line terminator ('\n' for test.txt, '\r\n' for trainval.txt)
            line = line.strip('\r\n')
            line = str(line)
            filelist.append(line)
        f.close()
        # print(filelist)
        return filelist
    
    
    if __name__ == '__main__':
        demo = CopyXml()
        demo.startcopy()
    

    After a successful run, the destination folders contain the images and XML files listed in the chosen txt. Run the script once with the val paths and val.txt (as shown) and again with the train paths and train.txt. Next comes the VOC-to-COCO conversion:

    # coding:utf-8
    import sys
    import os
    import json
    import xml.etree.ElementTree as ET
    
    START_BOUNDING_BOX_ID = 1
    # The dict below holds the actual detection categories; edit it to match your own data.
    # This example uses my own dataset's 10 classes; for the VOC dataset it would be 20.
    # The category names must match the label names in the xml files.
    PRE_DEFINE_CATEGORIES = {"Normal": 0, "No_wear": 1, "Playing_phone": 2, "Smoking": 3, "Sleeping": 4,
                             "Playing_phone_smoking": 5, "No_wear_playing_phone": 6, "No_wear_smoking": 7,
                             "No_wear_sleeping": 8, "No_wear_playing_phone_smoking": 9}
    
    
    def get(root, name):
        vars = root.findall(name)
        return vars
    
    
    def get_and_check(root, name, length):
        vars = root.findall(name)
        if len(vars) == 0:
            raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
        if length > 0 and len(vars) != length:
            raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
        if length == 1:
            vars = vars[0]
        return vars
    
    
    def get_filename_as_int(filename):
        # Despite the name, this returns the file stem as-is to use as the image id.
        filename = os.path.splitext(filename)[0]
        return filename
    
    
    def convert(xml_dir, json_file):
        xmlFiles = os.listdir(xml_dir)
    
        json_dict = {"images": [], "type": "instances", "annotations": [],
                     "categories": []}
        categories = PRE_DEFINE_CATEGORIES
        bnd_id = START_BOUNDING_BOX_ID
        num = 0
        for line in xmlFiles:
            # print("Processing %s" % (line))
            num += 1
            if num % 50 == 0:
                print("processing ", num, "; file ", line)
    
            xml_f = os.path.join(xml_dir, line)
            tree = ET.parse(xml_f)
            root = tree.getroot()
            # use the file stem as the image id
            filename = line[:-4]
            image_id = get_filename_as_int(filename)
            size = get_and_check(root, 'size', 1)
            width = int(get_and_check(size, 'width', 1).text)
            height = int(get_and_check(size, 'height', 1).text)
            image = {'file_name': (filename + '.jpg'), 'height': height, 'width': width,
                     'id': image_id}
            json_dict['images'].append(image)
            # Currently segmentation is not supported
            #  segmented = get_and_check(root, 'segmented', 1).text
            #  assert segmented == '0'
            for obj in get(root, 'object'):
                category = get_and_check(obj, 'name', 1).text
                if category not in categories:
                    new_id = len(categories)
                    categories[category] = new_id
                category_id = categories[category]
                bndbox = get_and_check(obj, 'bndbox', 1)
                xmin = int(get_and_check(bndbox, 'xmin', 1).text) - 1
                ymin = int(get_and_check(bndbox, 'ymin', 1).text) - 1
                xmax = int(get_and_check(bndbox, 'xmax', 1).text)
                ymax = int(get_and_check(bndbox, 'ymax', 1).text)
                assert (xmax > xmin)
                assert (ymax > ymin)
                o_width = abs(xmax - xmin)
                o_height = abs(ymax - ymin)
                ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id':
                    image_id, 'bbox': [xmin, ymin, o_width, o_height],
                       'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                       'segmentation': []}
                json_dict['annotations'].append(ann)
                bnd_id = bnd_id + 1
    
        for cate, cid in categories.items():
            cat = {'supercategory': 'none', 'id': cid, 'name': cate}
            json_dict['categories'].append(cat)
        json_fp = open(json_file, 'w')
        json_str = json.dumps(json_dict)
        json_fp.write(json_str)
        json_fp.close()
    
    
    if __name__ == '__main__':
        # run once with ["train_annotations"] and once with ["val_annotations"]
        folder_list = ["train_annotations"]
        # change base_dir to the actual local path of your annotation folders
        base_dir = "D:\\lunwen\\data\\data\\"
        for folderName in folder_list:
            xml_dir = base_dir + folderName
            json_dir = base_dir + folderName + "/instances_" + folderName + ".json"
            print("deal: ", folderName)
            print("xml dir: ", xml_dir)
            print("json file: ", json_dir)
            convert(xml_dir, json_dir)
    
    
    

    Finally, each run produces an instances_<folderName>.json inside the corresponding annotation folder.
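    A quick way to sanity-check a generated file is to load it and count the entries (a minimal sketch; adjust the path to your own output):

    import json

    with open(r'D:\lunwen\data\data\train_annotations\instances_train_annotations.json') as f:
        coco = json.load(f)
    print(len(coco['images']), 'images')
    print(len(coco['annotations']), 'annotations')
    print([c['name'] for c in coco['categories']])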

    3. Model Training

    Method 1

    Create a coco folder under detectron2/datasets and place the data into it as follows:

    coco
    
    ----train2017 #### created manually
    
    ----val2017 #### created manually
    
    ----annotations #### created manually
    
    --------instances_train2017.json #### generated by the script
    
    --------instances_val2017.json #### generated by the script
    
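    Note that the conversion script above writes instances_train_annotations.json inside train_annotations itself, so the images and json files have to be copied and renamed to match this layout. A sketch, assuming the Windows paths used earlier (the detectron2 path is a placeholder):

    import os
    import shutil

    base = r'D:\lunwen\data\data'
    coco = r'path\to\detectron2\datasets\coco'  # placeholder: your detectron2/datasets/coco

    os.makedirs(os.path.join(coco, 'annotations'), exist_ok=True)
    shutil.copytree(os.path.join(base, 'train_JPEGImages'), os.path.join(coco, 'train2017'))
    shutil.copytree(os.path.join(base, 'val_JPEGImages'), os.path.join(coco, 'val2017'))
    shutil.copyfile(os.path.join(base, 'train_annotations', 'instances_train_annotations.json'),
                    os.path.join(coco, 'annotations', 'instances_train2017.json'))
    shutil.copyfile(os.path.join(base, 'val_annotations', 'instances_val_annotations.json'),
                    os.path.join(coco, 'annotations', 'instances_val2017.json'))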

    Training on your own dataset requires registering the dataset and loading its metadata, which is fairly tedious. Following the official samples, I wrote a script, trainsample.py, placed under the model_train folder.

    import os
    import cv2
    
    from detectron2.utils.visualizer import Visualizer
    from detectron2.data import DatasetCatalog, MetadataCatalog
    from detectron2.data.datasets.coco import load_coco_json
    
    # dataset paths
    DATASET_ROOT = './datasets/coco'
    ANN_ROOT = os.path.join(DATASET_ROOT, 'annotations')
    
    TRAIN_PATH = os.path.join(DATASET_ROOT, 'train2017')
    VAL_PATH = os.path.join(DATASET_ROOT, 'val2017')
    
    TRAIN_JSON = os.path.join(ANN_ROOT, 'instances_train2017.json')
    VAL_JSON = os.path.join(ANN_ROOT, 'instances_val2017.json')
    
    CLASS_NAMES = ['Normal', 'No_wear', 'Playing_phone', 'Smoking', 'Sleeping', 'Playing_phone_smoking',
                   'No_wear_playing_phone', 'No_wear_smoking', 'No_wear_sleeping', 'No_wear_playing_phone_smoking']
    # dataset category metadata
    DATASET_CATEGORIES = [
        # {"name": "background", "id": 0, "isthing": 1, "color": [220, 20, 60]},
        {"name": "Normal", "id": 0, "isthing": 1, "color": [255, 0, 0]},  # red
        {"name": "No_wear", "id": 1, "isthing": 1, "color": [0, 255, 0]},  # green
        {"name": "Playing_phone", "id": 2, "isthing": 1, "color": [0, 0, 255]},  # blue
        {"name": "Smoking", "id": 3, "isthing": 1, "color": [255, 255, 0]},  # yellow
        {"name": "Sleeping", "id": 4, "isthing": 1, "color": [255, 0, 255]},  # magenta
        {"name": "Playing_phone_smoking", "id": 5, "isthing": 1, "color": [0, 255, 255]},  # cyan
        {"name": "No_wear_playing_phone", "id": 6, "isthing": 1, "color": [128, 0, 128]},  # dark purple
        {"name": "No_wear_smoking", "id": 7, "isthing": 1, "color": [128, 128, 0]},  # olive
        {"name": "No_wear_sleeping", "id": 8, "isthing": 1, "color": [0, 128, 128]},  # teal
        {"name": "No_wear_playing_phone_smoking", "id": 9, "isthing": 1, "color": [128, 128, 128]}  # gray
    ]
    
    # dataset splits
    PREDEFINED_SPLITS_DATASET = {
        "train_2019": (TRAIN_PATH, TRAIN_JSON),
        "val_2019": (VAL_PATH, VAL_JSON),
    }
    
    
    def register_dataset():
        """
        purpose: register all splits of the dataset in PREDEFINED_SPLITS_DATASET
        """
        for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items():
            register_dataset_instances(name=key,
                                       metadata=get_dataset_instances_meta(),
                                       json_file=json_file,
                                       image_root=image_root)
    
    
    def get_dataset_instances_meta():
        """
        purpose: get metadata of dataset from DATASET_CATEGORIES
        return: dict[metadata]
        """
        thing_ids = [k["id"] for k in DATASET_CATEGORIES if k["isthing"] == 1]
        thing_colors = [k["color"] for k in DATASET_CATEGORIES if k["isthing"] == 1]
        # assert len(thing_ids) == 2, len(thing_ids)
        thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
        thing_classes = [k["name"] for k in DATASET_CATEGORIES if k["isthing"] == 1]
        ret = {
            "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
            "thing_classes": thing_classes,
            "thing_colors": thing_colors,
        }
        return ret
    
    
    def register_dataset_instances(name, metadata, json_file, image_root):
        """
        purpose: register the dataset in DatasetCatalog and its
                 metadata in MetadataCatalog, setting attributes
        """
        DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
        MetadataCatalog.get(name).set(json_file=json_file,
                                      image_root=image_root,
                                      evaluator_type="coco",
                                      **metadata)
    
    
    # register the dataset and metadata without the helper above
    def plain_register_dataset():
        DatasetCatalog.register("train_2019", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH, "train_2019"))
        MetadataCatalog.get("train_2019").set(thing_classes=CLASS_NAMES,
                                              json_file=TRAIN_JSON,
                                              image_root=TRAIN_PATH)
        DatasetCatalog.register("val_2019", lambda: load_coco_json(VAL_JSON, VAL_PATH, "val_2019"))
        MetadataCatalog.get("val_2019").set(thing_classes=CLASS_NAMES,
                                            json_file=VAL_JSON,
                                            image_root=VAL_PATH)
    
    
    # visualize dataset annotations to verify the registration
    def checkout_dataset_annotation(name="train_2019"):
        dataset_dicts = load_coco_json(TRAIN_JSON, TRAIN_PATH, name)
        for d in dataset_dicts:
            img = cv2.imread(d["file_name"])
            visualizer = Visualizer(img[:, :, ::-1], metadata=MetadataCatalog.get(name), scale=1.5)
            vis = visualizer.draw_dataset_dict(d)
            cv2.imshow('show', vis.get_image()[:, :, ::-1])
            cv2.waitKey(0)
    
    
    register_dataset()
    
    checkout_dataset_annotation()
    

    The model architecture follows the one used in the slowfast source. Find the corresponding yaml file under configs and point trainsample.py at it; a sketch of this configuration step follows.
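    As listed, trainsample.py only registers and visualizes the dataset; the training entry point itself is not shown. A minimal sketch of what that step might look like, reusing register_dataset() from above (the yaml path and hyperparameters are assumptions based on this dataset):

    from detectron2.config import get_cfg
    from detectron2.engine import DefaultTrainer

    register_dataset()  # registers train_2019 / val_2019 as defined above

    cfg = get_cfg()
    cfg.merge_from_file('configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml')  # your chosen yaml
    cfg.DATASETS.TRAIN = ("train_2019",)
    cfg.DATASETS.TEST = ("val_2019",)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10  # matches the 10 classes registered above
    cfg.OUTPUT_DIR = "./output_trainsample/"
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()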

    Then run:

    $ python3 model_train/trainsample.py
    

    To test on a video:

    python demo/demo.py --config-file /root/autodl-tmp/detectron2/output_trainsample/config.yaml --video-input detect/1.avi --output demo/4.avi --opts MODEL.WEIGHTS /root/autodl-tmp/detectron2/output_trainsample/model_final.pth
    
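    demo.py also accepts still images via --input if you want to test single frames instead of a video (the image path here is just an example):

    python demo/demo.py --config-file /root/autodl-tmp/detectron2/output_trainsample/config.yaml --input detect/1.jpg --output demo/ --opts MODEL.WEIGHTS /root/autodl-tmp/detectron2/output_trainsample/model_final.pth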

    Method 2

    As in Method 1, create a coco folder under detectron2/datasets and place the data into it:

    coco
    
    ----train2017 #### created manually
    
    ----val2017 #### created manually
    
    ----annotations #### created manually
    
    --------instances_train2017.json #### generated by the script
    
    --------instances_val2017.json #### generated by the script
    

    Create a train_net.py file in the detectron2 directory

    with the following content:

    #!/usr/bin/env python
    # Copyright (c) Facebook, Inc. and its affiliates.
    """
    A main training script.
    
    This scripts reads a given config file and runs the training or evaluation.
    It is an entry point that is made to train standard models in detectron2.
    
    In order to let one script support training of many models,
    this script contains logic that are specific to these built-in models and therefore
    may not be suitable for your own project.
    For example, your research project perhaps only needs a single "evaluator".
    
    Therefore, we recommend you to use detectron2 as an library and take
    this file as an example of how to use the library.
    You may want to write your own script with your datasets and other customizations.
    """
    
    import logging
    import os
    from collections import OrderedDict
    
    import detectron2.utils.comm as comm
    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from detectron2.data import MetadataCatalog
    from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
    from detectron2.evaluation import (
        CityscapesInstanceEvaluator,
        CityscapesSemSegEvaluator,
        COCOEvaluator,
        COCOPanopticEvaluator,
        DatasetEvaluators,
        LVISEvaluator,
        PascalVOCDetectionEvaluator,
        SemSegEvaluator,
        verify_results,
    )
    from detectron2.modeling import GeneralizedRCNNWithTTA
    from detectron2.data.datasets import register_coco_instances
    
    
    def build_evaluator(cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
            evaluator_list.append(
                SemSegEvaluator(
                    dataset_name,
                    distributed=True,
                    output_dir=output_folder,
                )
            )
        if evaluator_type in ["coco", "coco_panoptic_seg"]:
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        if evaluator_type == "coco_panoptic_seg":
            evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
        if evaluator_type == "cityscapes_instance":
            return CityscapesInstanceEvaluator(dataset_name)
        if evaluator_type == "cityscapes_sem_seg":
            return CityscapesSemSegEvaluator(dataset_name)
        elif evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, output_dir=output_folder)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)
    
    
    class Trainer(DefaultTrainer):
        """
        We use the "DefaultTrainer" which contains pre-defined default logic for
        standard training workflow. They may not work for you, especially if you
        are working on a new research project. In that case you can write your
        own training loop. You can use "tools/plain_train_net.py" as an example.
        """
    
        @classmethod
        def build_evaluator(cls, cfg, dataset_name, output_folder=None):
            return build_evaluator(cfg, dataset_name, output_folder)
    
        @classmethod
        def test_with_TTA(cls, cfg, model):
            logger = logging.getLogger("detectron2.trainer")
            # In the end of training, run an evaluation with TTA
            # Only support some R-CNN models.
            logger.info("Running inference with test-time augmentation ...")
            model = GeneralizedRCNNWithTTA(cfg, model)
            evaluators = [
                cls.build_evaluator(
                    cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
                )
                for name in cfg.DATASETS.TEST
            ]
            res = cls.test(cfg, model, evaluators)
            res = OrderedDict({k + "_TTA": v for k, v in res.items()})
            return res
    
    
    def setup(args):
        """
        Create configs and perform basic setups.
        """
        cfg = get_cfg()
        args.config_file = "/root/autodl-tmp/detectron2/configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
        cfg.merge_from_file(args.config_file)
        cfg.merge_from_list(args.opts)
        # override config parameters
        cfg.DATASETS.TRAIN = ("coco_my_train",)  # name of the registered training dataset
        cfg.DATASETS.TEST = ("coco_my_val",)
        cfg.DATALOADER.NUM_WORKERS = 2  # number of dataloader workers
    
        #cfg.INPUT.CROP.ENABLED = True
        #cfg.INPUT.MAX_SIZE_TRAIN = 640  # maximum input size for training images
        #cfg.INPUT.MAX_SIZE_TEST = 640   # maximum input size at test time
        #cfg.INPUT.MIN_SIZE_TRAIN = (512, 768)  # minimum training size; a tuple enables multi-scale training
        #cfg.INPUT.MIN_SIZE_TEST = 640
        # cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING has two modes, 'choice' and 'range':
        #   'range'  - sample the short edge randomly from 512-768
        #   'choice' - resize the short edge to one of a fixed set of sizes, i.e. exactly 512 or 768
        #cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = 'range'
        # Read the note below carefully!
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 10
        #cfg.MODEL.RETINANET.NUM_CLASSES = 10  # number of classes + 1 (there is a background class when
        # your category ids start from 1; if your dataset json's ids start from 0, set this to the plain
        # class count -- no extra background class needed!)
        #cfg.MODEL.WEIGHTS = "/home/yourstorePath/.pth"
        cfg.MODEL.WEIGHTS = "/root/autodl-tmp/detectron2/model_final_280758.pkl"  # pretrained weights
        cfg.SOLVER.IMS_PER_BATCH = 8  # batch size; iters_in_one_epoch = dataset_imgs / batch_size
    
        # iterations per epoch, computed from the dataset size and batch size;
        # replace 50 with the total number of images in your training set
        ITERS_IN_ONE_EPOCH = int(50 / cfg.SOLVER.IMS_PER_BATCH)
    
        # maximum number of iterations
        cfg.SOLVER.MAX_ITER = 1500
        # initial learning rate
        cfg.SOLVER.BASE_LR = 0.002
        # optimizer momentum
        cfg.SOLVER.MOMENTUM = 0.9
        # weight decay
        cfg.SOLVER.WEIGHT_DECAY = 0.0001
        cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0
        # learning-rate decay factor
        cfg.SOLVER.GAMMA = 0.1
        # iterations at which the learning rate is decayed
        cfg.SOLVER.STEPS = (900,)
        # before training, ramp the learning rate up to the initial value
        cfg.SOLVER.WARMUP_FACTOR = 1.0 / 1000
        # number of warmup iterations
        cfg.SOLVER.WARMUP_ITERS = 500
        cfg.MODEL.DEVICE = 'cuda'
        cfg.SOLVER.WARMUP_METHOD = "linear"
        # checkpoint once per epoch (minus 1 so the filename numbering lines up)
        #cfg.SOLVER.CHECKPOINT_PERIOD = ITERS_IN_ONE_EPOCH - 1
        # run an evaluation every EVAL_PERIOD iterations
        cfg.TEST.EVAL_PERIOD = ITERS_IN_ONE_EPOCH
        cfg.TEST.EVAL_PERIOD = 100  # overrides the line above: evaluate every 100 iterations
        cfg.OUTPUT_DIR = "./output_trainsample/"
        cfg.freeze()
        default_setup(cfg, args)
        return cfg
    
    
    
    def main(args):
        cfg = setup(args)
        # "coco_my_train" is an arbitrary name for your dataset;
        # the arguments that follow are the paths to its json file and images.
        # training set
        register_coco_instances("coco_my_train", {}, "/root/autodl-tmp/detectron2/mydatasets/coco/annotations/instances_train2017.json",
                                "/root/autodl-tmp/detectron2/mydatasets/coco/train2017")
        # validation set
        register_coco_instances("coco_my_val", {}, "/root/autodl-tmp/detectron2/mydatasets/coco/annotations/instances_val2017.json",
                                "/root/autodl-tmp/detectron2/mydatasets/coco/val2017")
        if args.eval_only:
            model = Trainer.build_model(cfg)
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
                cfg.MODEL.WEIGHTS, resume=args.resume
            )
            res = Trainer.test(cfg, model)
            if cfg.TEST.AUG.ENABLED:
                res.update(Trainer.test_with_TTA(cfg, model))
            if comm.is_main_process():
                verify_results(cfg, res)
            return res
        """
            If you'd like to do anything fancier than the standard training logic,
            consider writing your own training loop (see plain_train_net.py) or
            subclassing the trainer.
            """
        trainer = Trainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        if cfg.TEST.AUG.ENABLED:
            trainer.register_hooks(
                [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
            )
        return trainer.train()
    
    """
    这部分代码可以利用注册数据集的元数据得到你的数据字典格式和内容,可以打印和加载注释图片
        custom_metadata = MetadataCatalog.get("coco_my_train")
        dataset_dicts = DatasetCatalog.get("coco_my_train")
        print(dataset_dicts)
        for d in random.sample(dataset_dicts, 1):
            img = cv2.imread(d["file_name"])
            #print(img)
            visualizer = Visualizer(img[:, :, ::-1], metadata=custom_metadata, scale=1)
            vis = visualizer.draw_dataset_dict(d)
            cv2.imshow('Sample',vis.get_image()[:, :, ::-1])
            cv2.waitKey(1500)
    """
    
    
    
    if __name__ == "__main__":
        args = default_argument_parser().parse_args()
        print("Command Line Args:", args)
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )
    

    Then run:

    python train_net.py
    
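    Since train_net.py uses detectron2's default_argument_parser, the standard flags are available. One caveat: setup() assigns cfg.MODEL.WEIGHTS after merging args.opts, so for an evaluation-only run, point that line at your trained model_final.pth first, then:

    # evaluate only, no training
    python train_net.py --eval-only
    # train on 2 GPUs, resuming from the last checkpoint in OUTPUT_DIR
    python train_net.py --num-gpus 2 --resume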

    The demo command is the same as in Method 1:

    python demo/demo.py --config-file /root/autodl-tmp/detectron2/output_trainsample/config.yaml --video-input /root/autodl-tmp/detectron2/detect/1.avi  --output demo/4.avi --opts MODEL.WEIGHTS /root/autodl-tmp/detectron2/output_trainsample/model_final.pth
    
    

    4. Common Issues

    1. Missing cv2 module:

    pip install opencv-python
    

    2. Pillow deprecation warning:

    `LINEAR` is deprecated and will be removed in Pillow 10 (2023-07-01). Use BILINEAR or Resampling.BILINEAR instead. 
    

    Fix by pinning the Pillow version:

    pip install pillow==9.5.0
    
  • Original article: https://blog.csdn.net/qq_59159431/article/details/131127847