• 基于Yolov8网络进行目标检测(二)-安装和自定义数据集


    关于Yolov8的安装在前一个环节忽略了,其实非常简单,只需要以下两个步骤:

    1、安装pytorch

    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

    2、安装ultralytics

    pip install  ultralytics

    为什么把目录结构单独拿出来扯呢?这个和训练自己的数据集息息相关。

    首先我们要知道YOLOv8这次发行中带的预训练模型,是是基于COCO val2017 数据集训练的结果。

    Coco2017数据集是具有80个类别的大规模数据集,其数据分为三部分:训练、验证和测试,每部分分别包含 118287, 5000 和 40670张图片,总大小约25g。其中测试数据集没有标注信息,所以注释部分只有训练和验证的

    我们看一下yolo进行模型训练的方法,一种是CLI方式,一种是Python方式

    CLI方式:

    1. # Build a new model from YAML and start training from scratch
    2. yolo detect train data=coco128.yaml model=yolov8n.yaml epochs=100 imgsz=640
    3. # Start training from a pretrained *.pt model
    4. yolo detect train data=coco128.yaml model=yolov8n.pt epochs=100 imgsz=640
    5. # Build a new model from YAML, transfer pretrained weights to it and start training
    6. yolo detect train data=coco128.yaml model=yolov8n.yaml pretrained=yolov8n.pt epochs=100 imgsz=640

    Python方式:

    1. from ultralytics import YOLO
    2. # Load a model
    3. model = YOLO('yolov8n.yaml') # build a new model from YAML
    4. model = YOLO('yolov8n.pt') # load a pretrained model (recommended for training)
    5. model = YOLO('yolov8n.yaml').load('yolov8n.pt') # build from YAML and transfer weights
    6. # Train the model
    7. results = model.train(data='coco128.yaml', epochs=100, imgsz=640)

    我们以CLI方式为例

    mode: 选择是训练、验证还是预测的任务蕾西 可选['train', 'val', 'predict']

    model: 选择yolov8不同的预训练模型,可选yolov8s.pt、yolov8m.pt、yolov8l.pt、yolov8x.pt;或选择yolov8不同的模型配置文件,可选yolov8s.yaml、yolov8m.yaml、yolov8l.yaml、yolov8x.yaml

    data: 选择生成的数据集配置文件

    epochs:指的就是训练过程中整个数据集将被迭代多少次,显卡不行你就调小点。

    batch:一次看完多少张图片才进行权重更新,梯度下降的mini-batch,显卡不行你就调小点

    其中data和model要画重点,data是要自己训练的数据集配置文件。

    model一般是预训练模型,通常用yolov8n.pt、yolov8s.pt、yolov8m.pt、yolov8l.pt、yolov8x.pt就可以了,但如果想自己指定训练配置文件呢?这个时候,model就使用yolov8n.yaml等网络配置文件, 增加参数pretrained使用yolov8n.pt了。

    这些文件在哪儿呢?

    到项目所在的venv\Lib\site-packages\ultralytics目录下,看两个重要的目录cfg/datasets和cfg/models/v8

    1. \\venv\Lib\site-packages\ultralytics>
    2. ├─assets
    3. ├─cfg
    4. │ ├─datasets
    5. │ ├─models
    6. │ │ ├─rt-detr
    7. │ │ ├─v3
    8. │ │ ├─v5
    9. │ │ ├─v6
    10. │ │ └─v8

    yolov8内置了以下模型配置文件

    6ab5a6a5f3065ceba2f7582f09eb2ad1.png

    我们看一下yolov8.yaml文件,里面包含了标签总数,yolo几种不同训练模型的Layer数量、参数量、梯度量;骨干网的结构、Head的结构。

    要做的事情很简单,基于yolov8.yaml另外复制一份基于训练集命名的文件,只需要修改nc后面的标签总数即可,在训练前可以认为标签总数是已知的。

    347799271c1596e562475adaeefccb30.png

    数据集配置文件还内置Argoverse.yaml、coco-pose.yaml、coco.yaml、coco128-seg.yaml、coco128.yaml、coco8-pose.yaml、coco8-seg.yaml、coco8.yaml、data.yaml、DOTAv2.yaml、GlobalWheat2020.yaml、ImageNet.yaml、Objects365.yaml、open-images-v7.yaml、SKU-110K.yaml、VisDrone.yaml、VOC.yaml、xView.yaml等模板。

    我们看一下coco128.yaml文件,里面包含path(数据集根目录)、train(训练集图片路径))、val(验证集图片路径)、test(测试集图片路径);标签列表清单,按照序号:标签名的方式进行枚举,最后还包括了一个Download script/URL (optional)信息,即下载脚本和路径,这个是可选项 。

    8d094373d017cfdfcd6327f06e8718a8.png

    要做的事情很简单,基于coco128.yaml另外复制一份基于训练集命名VOC2012.yaml(我这里是VOC2012)的文件,只需要修改path、train、val、test路径即可;同时需要修改names下的标签列表,然后把多余的download脚本剔除掉,因为假设我们已经提前下载并标注了图片。

    3c2fe33eb8427ab1a3bd3ab167ae449e.png

    再回过头来看一下数据集的组织,在我们的项目根目录下增加一下datasets目录,然后每个目录一个文件夹,文件夹下包括images(图片文件夹)和label(标签文件夹),images放置train、val、test等图片目录,label下一般会放在train、val等标注信息。

    1. └─datasets
    2. ├─coco128
    3. │ ├─images
    4. │ │ └─train2017
    5. │ └─labels
    6. │ └─train2017
    7. └─VOC2012
    8. ├─images
    9. │ └─train
    10. └─labels
    11. └─train

    这个目录该怎么放数据呢?按照正常的做法是先下载VOC2012数据集

    VOC2012数据集包括二十个对象类别:

    Person :person

    Animal :bird, cat, cow, dog, horse, sheep

    Vehicle :aeroplane, bicycle, boat, bus, car, motorbike, train

    Indoor :bottle, chair, dining table, potted plant, sofa, tv/monitor

    VOC2012数据集的目录结构如下:

    1. └─VOCdevkit
    2. └─VOC2012
    3. ├─Annotations
    4. ├─ImageSets
    5. │ ├─Action
    6. │ ├─Layout
    7. │ ├─Main
    8. │ └─Segmentation
    9. ├─JPEGImages
    10. ├─SegmentationClass
    11. └─SegmentationObject

    其中Annotation是标注文件夹,JPEGImages是图片文件夹,基本用到这两个目录,正常情况下我们先会区分训练集、验证集和测试集,当然这次没这么做。不过可以看一下代码,后续做也可以。

    1. import os
    2. import random
    3. import argparse
    4. parser = argparse.ArgumentParser()
    5. #xml文件的地址,根据自己的数据进行修改 xml一般存放在Annotations下
    6. parser.add_argument('--xml_path', default='VOCdevkit/VOC2012/Annotations', type=str, help='input xml label path')
    7. #数据集的划分,地址选择自己数据下的ImageSets/Main
    8. parser.add_argument('--txt_path', default='VOCdevkit/VOC2012/ImageSets/Main', type=str, help='output txt label path')
    9. opt = parser.parse_args()
    10. # Namespace(xml_path='VOCdevkit/VOC2012/Annotations', txt_path='VOCdevkit/VOC2012/ImageSets/Main')
    11. # 训练+验证集一共所占的比例为0.8,剩下的0.2就是测试集
    12. # (train+val)/(train+val+test)=80%
    13. trainval_percent = 0.8
    14. # (train)/(train+val)=80%
    15. # 训练集在训练集和验证集总集合中占的比例
    16. train_percent = 0.8
    17. xmlfilepath = opt.xml_path
    18. # VOCdevkit/VOC2012/Annotations
    19. txtsavepath = opt.txt_path
    20. # VOCdevkit/dataset/ImageSets/Main
    21. # 获取标注文件数量
    22. total_xml = os.listdir(xmlfilepath)
    23. # 创建文件目录
    24. if not os.path.exists(txtsavepath):
    25. os.makedirs(txtsavepath)
    26. # 随机打散文件序号,生成trainval和train两个随机数组
    27. num = len(total_xml)
    28. list_index = range(num)
    29. tv = int(num * trainval_percent)
    30. tr = int(tv * train_percent)
    31. trainval = random.sample(list_index, tv)
    32. train = random.sample(trainval, tr)
    33. fileTrainVal = open(txtsavepath + '/trainval.txt', 'w')
    34. fileTrain = open(txtsavepath + '/train.txt', 'w')
    35. fileVal = open(txtsavepath + '/val.txt', 'w')
    36. fileTest = open(txtsavepath + '/test.txt', 'w')
    37. for i in list_index:
    38. # 获取文件名
    39. name = total_xml[i][:-4] + '\n'
    40. # 根据trainval,train,val,test的顺序依次写入相关文件
    41. if i in trainval:
    42. fileTrainVal.write(name)
    43. if i in train:
    44. fileTrain.write(name)
    45. else:
    46. fileVal.write(name)
    47. else:
    48. fileTest.write(name)
    49. fileTrainVal.close()
    50. fileTrain.close()
    51. fileVal.close()
    52. fileTest.close()

    再次是对VOC2012的标注文件XML转换为Yolo的Txt标注格式。

    注这里的classes顺序要和上面的VOC2012.yaml中的name保持一致,否则会出现标签名称不对应的情况。

    1. # -*- coding: utf-8 -*-
    2. import xml.etree.ElementTree as ET
    3. import os
    4. sets = ['train', 'val', 'test']
    5. classes = ["aeroplane", 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    6. 'horse', 'motorcycle', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
    7. absPath = os.getcwd()
    8. def convert(size, box):
    9. '''
    10. :param size: 图片size
    11. :param box: 标注框坐标
    12. :return:
    13. VOC->YOLO转换算法
    14. norm_x=(xmin + xmax)/2/width
    15. norm_y=(ymin + ymax)/2/height
    16. norm_w=(xmax - xmin)/width
    17. norm_h=(ymax - ymin)/height
    18. YOLO->VOC转换算法
    19. xmin=width * (norm_x - 0.5 * norm_w)
    20. ymin=height * (norm_y - 0.5 * norm_h)
    21. xmax=width * (norm_x + 0.5 * norm_w)
    22. ymax=height * (norm_y + 0.5 * norm_h)
    23. '''
    24. dw = 1. / (size[0])
    25. dh = 1. / (size[1])
    26. x = (box[0] + box[1]) / 2.0 - 1
    27. y = (box[2] + box[3]) / 2.0 - 1
    28. w = box[1] - box[0]
    29. h = box[3] - box[2]
    30. x = x * dw
    31. w = w * dw
    32. y = y * dh
    33. h = h * dh
    34. return x, y, w, h
    35. def ConvertAnnotation(image_id):
    36. inputFile = open(absPath + '/VOCdevkit/VOC2012/Annotations/%s.xml' % (image_id), encoding='UTF-8')
    37. outFile = open(absPath + '/VOCdevkit/VOC2012/YOLOLabels/%s.txt' % (image_id), 'w')
    38. '''
    39. VOC2012 标注格式
    40. VOC2012
    41. 2008_007069.jpg
    42. The VOC2008 Database
    43. PASCAL VOC2008
    44. flickr
    45. 500
    46. 375
    47. 3
    48. 0
    49. sheep
    50. Right
    51. 0
    52. 0
    53. 411
    54. 172
    55. 445
    56. 195
    57. 0
    58. '''
    59. '''
    60. Yolo 标注文件格式
    61. labelclass xCenter yCenter width height
    62. 每个标签有五个数据,依次代表:
    63. 所标注内容的类别,数字与类别一一对应
    64. 1、labelclass 标注框类别 labelclass
    65. 2、xCenter 归一化后标注框的中心点的x轴
    66. 3、yCenter 归一化后标注框的中心点的y轴
    67. 4、width 归一化后目标框的宽度
    68. 5、height 归一化后目标框的高度
    69. '''
    70. tree = ET.parse(inputFile)
    71. root = tree.getroot()
    72. # 获取标注图片的大小
    73. size = root.find('size')
    74. width = int(size.find('width').text)
    75. height = int(size.find('height').text)
    76. # 获取标注框信息
    77. for obj in root.iter('object'):
    78. difficult = obj.find('difficult').text
    79. # 获取标注类别名称
    80. cls = obj.find('name').text
    81. if cls not in classes or int(difficult) == 1:
    82. continue
    83. # 将标注类别按照classes列表信息转换为索引ID
    84. clsId = classes.index(cls)
    85. # 获取标注框信息
    86. xmlBox = obj.find('bndbox')
    87. boundry = (float(xmlBox.find('xmin').text), float(xmlBox.find('xmax').text), float(xmlBox.find('ymin').text),
    88. float(xmlBox.find('ymax').text))
    89. xmin, xmax, ymin, ymax = boundry
    90. # 标注越界修正
    91. if xmax > width:
    92. xmax = width
    93. if ymax > height:
    94. ymax = height
    95. box = (xmin, xmax, ymin, ymax)
    96. transBox = convert((width, height), box)
    97. outFile.write(str(clsId) + " " + " ".join([str(a) for a in transBox]) + '\n')
    98. # 判断标注转换目录是否存在
    99. if not os.path.exists(absPath + '/VOCdevkit/VOC2012/YOLOLabels/'):
    100. os.makedirs(absPath + '/VOCdevkit/VOC2012/YOLOLabels/')
    101. for imageSet in sets:
    102. # 获取当前文件(train/val/test)的图片ID
    103. imageIds = open(absPath + '/VOCdevkit/VOC2012/ImageSets/Main/%s.txt' % (imageSet)).read().strip().split()
    104. listFile = open(absPath + '/VOCdevkit/VOC2012/%s.txt' % (imageSet), 'w')
    105. for imageId in imageIds:
    106. # 遍历文件名列表,分别将图片文件全路径写入新的文件中
    107. listFile.write(absPath + '/VOCdevkit/VOC2012/JPEGImages/%s.jpg\n' % (imageId))
    108. # 进行文件格式转换
    109. ConvertAnnotation(imageId)
    110. listFile.close()

    关于Yolov8训练自己的数据集的前序准备工作已完成,后续讲一下怎么开展训练过程。

    最后欢迎关注公众号:python与大数据分析

    ae8a568a141ae3280babe1311bb4ae31.jpeg

  • 相关阅读:
    《数字图像处理-OpenCV/Python》连载(41)图像的旋转
    SQL User-Agent注入详解
    【Python 零基础入门】Pandas
    精通Nginx(08)-反向代理
    python输出唐诗 青少年编程电子学会python编程等级考试二级真题解析2021年12月
    Spring IoC容器简介说明(BeanFactory和ApplicationContext)
    java设计模式---建造者模式
    【JVM】jvm的体系结构
    PHP基础面试题
    MySQL优化01-索引
  • 原文地址:https://blog.csdn.net/baoqiangwang/article/details/132893528