• yolo系列模型训练数据集全流程制作方法(附数据增强代码)


    yolo系列的模型在目标检测领域里面受众非常广,也十分流行,但是在使用yolo进行目标检测训练的时候,往往要将VOC格式的数据集转化为yolo专属的数据集,而yolo的训练数据集制作方法呢,最常见的也是有两种,下面我们只讲述一种最常用的方法,也是我最常使用的。

    1. voc转yolo格式

    我最常使用的目标检测数据集为VOC格式,而它的格式一般如下所示:

    - dataset
      |- annotations
      |  |- image1.xml
      |  |- image2.xml
      |  |- ...
      |
      |- images
      |  |- image1.jpg
      |  |- image2.jpg
      |  |- ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • dataset 是数据集的根目录。
    • annotations 目录包含每个图像对应的 XML 注释文件。
    • images 目录包含每个图像文件。

    而我们要转换的yolo格式如下所示:

    - dataset
      |- images
      |  |- image1.jpg
      |  |- image2.jpg
      |  |- ...
      |
      |- labels
      |  |- image1.txt
      |  |- image2.txt
      |  |- ...
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • dataset 是数据集的根目录。
    • images 目录包含每个图像文件,通常是以 .jpg 或 .png 等格式保存的图像文件。
    • labels 目录包含每个图像对应的标签文件,通常是以 .txt 格式保存的文本文件。

    而 labels 里面的内容填写格式为下图所示:
    在这里插入图片描述

    通常,每行的格式为:class x_center y_center width height,其中class代表的是图片中目标所对应的类别,x_center, y_center是边界框的中心点坐标相对于图像宽度和高度的归一化值,widthheight 是边界框的宽度和高度相对于图像宽度和高度的归一化值。
    举例如下:
    在这里插入图片描述

    转换代码:

    import xml.etree.ElementTree as ET
    import pickle
    import os
    from os import listdir, getcwd
    from os.path import join
    
    
    def convert(size, box):
        x_center = (box[0] + box[1]) / 2.0
        y_center = (box[2] + box[3]) / 2.0
        x = x_center / size[0]
        y = y_center / size[1]
    
        w = (box[1] - box[0]) / size[0]
        h = (box[3] - box[2]) / size[1]
    
        return (x, y, w, h)
    
    
    def convert_annotation(xml_files_path, save_txt_files_path, classes):
        xml_files = os.listdir(xml_files_path)
        for xml_name in xml_files:
            xml_file = os.path.join(xml_files_path, xml_name)
            out_txt_path = os.path.join(save_txt_files_path, xml_name.split('.')[0] + '.txt')
            out_txt_f = open(out_txt_path, 'w')
            tree = ET.parse(xml_file)
            root = tree.getroot()
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)
    
            for obj in root.iter('object'):
                #difficult = obj.find('difficult').text
                cls = obj.find('name').text
                #if cls not in classes or int(difficult) == 1:
                    #continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
                     float(xmlbox.find('ymax').text))
                # b=(xmin, xmax, ymin, ymax)
                # print(w, h, b)
                bb = convert((w, h), b)
                out_txt_f.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
    
    
    if __name__ == "__main__":
        # 把forklift_pallet的voc的xml标签文件转化为yolo的txt标签文件
        # 1、需要转化的类别,这里我直接用数字代表类别,由于我是八类,所以从0到7
        classes = ['0', '1', '2', '3', '4', '5', '6', '7']
        # 2、voc格式的xml标签文件路径
        xml_files1 = 'annotations'
        # 3、转化为yolo格式的txt标签文件存储路径
        save_txt_files1 = 'labels'
    
        convert_annotation(xml_files1, save_txt_files1, classes)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56

    上面代码中注释了一部分内容,比如difficult这一项,由于我xml文件里面没有difficult,所以就注释掉了,大家按照自己的需求进行使用即可。

    划分数据集

    在我们进行yolo目标检测模型训练之前,需要先将数据集进行合理的划分,比如说划分为训练集:验证集=8:2,或者训练集:验证集:测试集=7:2:1。不过我一般习惯只划分训练集和验证集,也就是按8:2的比例进行划分,代码如下所示:

    import os
    import shutil
    import random
    
    # 定义数据集文件夹路径
    dataset_path = 'dataset'
    images_path = os.path.join(dataset_path, 'images')
    labels_path = os.path.join(dataset_path, 'labels')
    
    # 定义划分后的文件夹路径
    new_path = 'mydata'
    train_path = os.path.join(new_path, 'train')
    val_path = os.path.join(new_path, 'val')
    
    # 创建train和val文件夹
    os.makedirs(os.path.join(train_path, 'images'), exist_ok=True)
    os.makedirs(os.path.join(train_path, 'labels'), exist_ok=True)
    os.makedirs(os.path.join(val_path, 'images'), exist_ok=True)
    os.makedirs(os.path.join(val_path, 'labels'), exist_ok=True)
    
    # 获取所有图片文件的文件名
    image_files = os.listdir(images_path)
    # 随机打乱文件顺序
    random.shuffle(image_files)
    
    # 定义验证集所占比例
    val_split = 0.1
    # 计算验证集大小
    num_val = int(len(image_files) * val_split)
    
    # 将数据集按照比例划分到train和val文件夹中
    for i, image_file in enumerate(image_files):
        src_image = os.path.join(images_path, image_file)
        src_label = os.path.join(labels_path, image_file.replace('.jpg', '.txt'))
        if i < num_val:
            dst_image = os.path.join(val_path, 'images', image_file)
            dst_label = os.path.join(val_path, 'labels', image_file.replace('.jpg', '.txt'))
        else:
            dst_image = os.path.join(train_path, 'images', image_file)
            dst_label = os.path.join(train_path, 'labels', image_file.replace('.jpg', '.txt'))
        shutil.copy(src_image, dst_image)
        shutil.copy(src_label, dst_label)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    划分完成以后的文件夹格式为:

    - mydata
      |- train
      |  |- images
      |  |- labels
      |
      |- val
      |  |- images
      |  |- labels
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    images和labels分别是对应的数据集图片和txt标签。

    数据增强

    在我们参加一些目标检测类比赛的时候,往往会遇见比赛训练集不足的情况,这将极大程度上影响我们的模型精度,这时候可能就需要用到一些数据增强方法,如翻转、随机裁剪等等。当然,yolo系列的模型一般都自带有数据增强,但是我们也可以尝试训练前进行增强看看效果。
    代码如下:

    # -*- coding=utf-8 -*-
    
    import time
    import random
    import copy
    import cv2
    import os
    import math
    import numpy as np
    from skimage.util import random_noise
    from lxml import etree, objectify
    import xml.etree.ElementTree as ET
    import argparse
    
    
    # 显示图片
    def show_pic(img, bboxes=None):
        '''
        输入:
            img:图像array
            bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....]
            names:每个box对应的名称
        '''
        for i in range(len(bboxes)):
            bbox = bboxes[i]
            x_min = bbox[0]
            y_min = bbox[1]
            x_max = bbox[2]
            y_max = bbox[3]
            cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
        cv2.namedWindow('pic', 0)  # 1表示原图
        cv2.moveWindow('pic', 0, 0)
        cv2.resizeWindow('pic', 1200, 800)  # 可视化的图片大小
        cv2.imshow('pic', img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    
    
    # 图像均为cv2读取
    class DataAugmentForObjectDetection():
        def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
                     crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
                     add_noise_rate=0.5, flip_rate=0.5,
                     cutout_rate=0.5, cut_out_length=50, cut_out_holes=1, cut_out_threshold=0.5,
                     is_addNoise=True, is_changeLight=True, is_cutout=True, is_rotate_img_bbox=True,
                     is_crop_img_bboxes=True, is_shift_pic_bboxes=True, is_filp_pic_bboxes=True):
    
            # 配置各个操作的属性
            self.rotation_rate = rotation_rate
            self.max_rotation_angle = max_rotation_angle
            self.crop_rate = crop_rate
            self.shift_rate = shift_rate
            self.change_light_rate = change_light_rate
            self.add_noise_rate = add_noise_rate
            self.flip_rate = flip_rate
            self.cutout_rate = cutout_rate
    
            self.cut_out_length = cut_out_length
            self.cut_out_holes = cut_out_holes
            self.cut_out_threshold = cut_out_threshold
    
            # 是否使用某种增强方式
            self.is_addNoise = is_addNoise
            self.is_changeLight = is_changeLight
            self.is_cutout = is_cutout
            self.is_rotate_img_bbox = is_rotate_img_bbox
            self.is_crop_img_bboxes = is_crop_img_bboxes
            self.is_shift_pic_bboxes = is_shift_pic_bboxes
            self.is_filp_pic_bboxes = is_filp_pic_bboxes
    
        # ----1.加噪声---- #
        def _addNoise(self, img):
            '''
            输入:
                img:图像array
            输出:
                加噪声后的图像array,由于输出的像素是在[0,1]之间,所以得乘以255
            '''
            # return cv2.GaussianBlur(img, (11, 11), 0)
            return random_noise(img, mode='gaussian', seed=int(time.time()), clip=True) * 255
    
        # ---2.调整亮度--- #
        def _changeLight(self, img):
            alpha = random.uniform(0.35, 1)
            blank = np.zeros(img.shape, img.dtype)
            return cv2.addWeighted(img, alpha, blank, 1 - alpha, 0)
    
        # ---3.cutout--- #
        def _cutout(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
            '''
            原版本:https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
            Randomly mask out one or more patches from an image.
            Args:
                img : a 3D numpy array,(h,w,c)
                bboxes : 框的坐标
                n_holes (int): Number of patches to cut out of each image.
                length (int): The length (in pixels) of each square patch.
            '''
    
            def cal_iou(boxA, boxB):
                '''
                boxA, boxB为两个框,返回iou
                boxB为bouding box
                '''
                # determine the (x, y)-coordinates of the intersection rectangle
                xA = max(boxA[0], boxB[0])
                yA = max(boxA[1], boxB[1])
                xB = min(boxA[2], boxB[2])
                yB = min(boxA[3], boxB[3])
    
                if xB <= xA or yB <= yA:
                    return 0.0
    
                # compute the area of intersection rectangle
                interArea = (xB - xA + 1) * (yB - yA + 1)
    
                # compute the area of both the prediction and ground-truth
                # rectangles
                boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
                boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
                iou = interArea / float(boxBArea)
                return iou
    
            # 得到h和w
            if img.ndim == 3:
                h, w, c = img.shape
            else:
                _, h, w, c = img.shape
            mask = np.ones((h, w, c), np.float32)
            for n in range(n_holes):
                chongdie = True  # 看切割的区域是否与box重叠太多
                while chongdie:
                    y = np.random.randint(h)
                    x = np.random.randint(w)
    
                    y1 = np.clip(y - length // 2, 0,
                                 h)  # numpy.clip(a, a_min, a_max, out=None), clip这个函数将将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min
                    y2 = np.clip(y + length // 2, 0, h)
                    x1 = np.clip(x - length // 2, 0, w)
                    x2 = np.clip(x + length // 2, 0, w)
    
                    chongdie = False
                    for box in bboxes:
                        if cal_iou([x1, y1, x2, y2], box) > threshold:
                            chongdie = True
                            break
                mask[y1: y2, x1: x2, :] = 0.
            img = img * mask
            return img
    
        # ---4.旋转--- #
        def _rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
            '''
            参考:https://blog.csdn.net/u014540717/article/details/53301195crop_rate
            输入:
                img:图像array,(h,w,c)
                bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
                angle:旋转角度
                scale:默认1
            输出:
                rot_img:旋转后的图像array
                rot_bboxes:旋转后的boundingbox坐标list
            '''
            # 旋转图像
            w = img.shape[1]
            h = img.shape[0]
            # 角度变弧度
            rangle = np.deg2rad(angle)  # angle in radians
            # now calculate new image width and height
            nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
            nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
            # ask OpenCV for the rotation matrix
            rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
            # calculate the move from the old center to the new center combined
            # with the rotation
            rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
            # the move only affects the translation, so update the translation
            rot_mat[0, 2] += rot_move[0]
            rot_mat[1, 2] += rot_move[1]
            # 仿射变换
            rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
    
            # 矫正bbox坐标
            # rot_mat是最终的旋转矩阵
            # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下
            rot_bboxes = list()
            for bbox in bboxes:
                xmin = bbox[0]
                ymin = bbox[1]
                xmax = bbox[2]
                ymax = bbox[3]
                point1 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymin, 1]))
                point2 = np.dot(rot_mat, np.array([xmax, (ymin + ymax) / 2, 1]))
                point3 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymax, 1]))
                point4 = np.dot(rot_mat, np.array([xmin, (ymin + ymax) / 2, 1]))
                # 合并np.array
                concat = np.vstack((point1, point2, point3, point4))
                # 改变array类型
                concat = concat.astype(np.int32)
                # 得到旋转后的坐标
                rx, ry, rw, rh = cv2.boundingRect(concat)
                rx_min = rx
                ry_min = ry
                rx_max = rx + rw
                ry_max = ry + rh
                # 加入list中
                rot_bboxes.append([rx_min, ry_min, rx_max, ry_max])
    
            return rot_img, rot_bboxes
    
        # ---5.裁剪--- #
        def _crop_img_bboxes(self, img, bboxes):
            '''
            裁剪后的图片要包含所有的框
            输入:
                img:图像array
                bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
            输出:
                crop_img:裁剪后的图像array
                crop_bboxes:裁剪后的bounding box的坐标list
            '''
            # 裁剪图像
            w = img.shape[1]
            h = img.shape[0]
            x_min = w  # 裁剪后的包含所有目标框的最小的框
            x_max = 0
            y_min = h
            y_max = 0
            for bbox in bboxes:
                x_min = min(x_min, bbox[0])
                y_min = min(y_min, bbox[1])
                x_max = max(x_max, bbox[2])
                y_max = max(y_max, bbox[3])
    
            d_to_left = x_min  # 包含所有目标框的最小框到左边的距离
            d_to_right = w - x_max  # 包含所有目标框的最小框到右边的距离
            d_to_top = y_min  # 包含所有目标框的最小框到顶端的距离
            d_to_bottom = h - y_max  # 包含所有目标框的最小框到底部的距离
    
            # 随机扩展这个最小框
            crop_x_min = int(x_min - random.uniform(0, d_to_left))
            crop_y_min = int(y_min - random.uniform(0, d_to_top))
            crop_x_max = int(x_max + random.uniform(0, d_to_right))
            crop_y_max = int(y_max + random.uniform(0, d_to_bottom))
    
            # 随机扩展这个最小框 , 防止别裁的太小
            # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))
            # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))
            # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))
            # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))
    
            # 确保不要越界
            crop_x_min = max(0, crop_x_min)
            crop_y_min = max(0, crop_y_min)
            crop_x_max = min(w, crop_x_max)
            crop_y_max = min(h, crop_y_max)
    
            crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]
    
            # 裁剪boundingbox
            # 裁剪后的boundingbox坐标计算
            crop_bboxes = list()
            for bbox in bboxes:
                crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min])
    
            return crop_img, crop_bboxes
    
        # ---6.平移--- #
        def _shift_pic_bboxes(self, img, bboxes):
            '''
            平移后的图片要包含所有的框
            输入:
                img:图像array
                bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
            输出:
                shift_img:平移后的图像array
                shift_bboxes:平移后的bounding box的坐标list
            '''
            # 平移图像
            w = img.shape[1]
            h = img.shape[0]
            x_min = w  # 裁剪后的包含所有目标框的最小的框
            x_max = 0
            y_min = h
            y_max = 0
            for bbox in bboxes:
                x_min = min(x_min, bbox[0])
                y_min = min(y_min, bbox[1])
                x_max = max(x_max, bbox[2])
                y_max = max(y_max, bbox[3])
    
            d_to_left = x_min  # 包含所有目标框的最大左移动距离
            d_to_right = w - x_max  # 包含所有目标框的最大右移动距离
            d_to_top = y_min  # 包含所有目标框的最大上移动距离
            d_to_bottom = h - y_max  # 包含所有目标框的最大下移动距离
    
            x = random.uniform(-(d_to_left - 1) / 3, (d_to_right - 1) / 3)
            y = random.uniform(-(d_to_top - 1) / 3, (d_to_bottom - 1) / 3)
    
            M = np.float32([[1, 0, x], [0, 1, y]])  # x为向左或右移动的像素值,正为向右负为向左; y为向上或者向下移动的像素值,正为向下负为向上
            shift_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
    
            #  平移boundingbox
            shift_bboxes = list()
            for bbox in bboxes:
                shift_bboxes.append([bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y])
    
            return shift_img, shift_bboxes
    
        # ---7.镜像--- #
        def _filp_pic_bboxes(self, img, bboxes):
            '''
                平移后的图片要包含所有的框
                输入:
                    img:图像array
                    bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
                输出:
                    flip_img:平移后的图像array
                    flip_bboxes:平移后的bounding box的坐标list
            '''
            # 翻转图像
    
            flip_img = copy.deepcopy(img)
            h, w, _ = img.shape
    
            sed = random.random()
    
            if 0 < sed < 0.33:  # 0.33的概率水平翻转,0.33的概率垂直翻转,0.33是对角反转
                flip_img = cv2.flip(flip_img, 0)  # _flip_x
                inver = 0
            elif 0.33 < sed < 0.66:
                flip_img = cv2.flip(flip_img, 1)  # _flip_y
                inver = 1
            else:
                flip_img = cv2.flip(flip_img, -1)  # flip_x_y
                inver = -1
    
            # 调整boundingbox
            flip_bboxes = list()
            for box in bboxes:
                x_min = box[0]
                y_min = box[1]
                x_max = box[2]
                y_max = box[3]
    
                if inver == 0:
                    # 0:垂直翻转
                    flip_bboxes.append([x_min, h - y_max, x_max, h - y_min])
                elif inver == 1:
                    # 1:水平翻转
                    flip_bboxes.append([w - x_max, y_min, w - x_min, y_max])
                elif inver == -1:
                    # -1:水平垂直翻转
                    flip_bboxes.append([w - x_max, h - y_max, w - x_min, h - y_min])
            return flip_img, flip_bboxes
    
        # 图像增强方法
        def dataAugment(self, img, bboxes):
            '''
            图像增强
            输入:
                img:图像array
                bboxes:该图像的所有框坐标
            输出:
                img:增强后的图像
                bboxes:增强后图片对应的box
            '''
            change_num = 0  # 改变的次数
            # print('------')
            while change_num < 1:  # 默认至少有一种数据增强生效
    
                if self.is_rotate_img_bbox:
                    if random.random() > self.rotation_rate:  # 旋转
                        change_num += 1
                        angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
                        scale = random.uniform(0.7, 0.8)
                        img, bboxes = self._rotate_img_bbox(img, bboxes, angle, scale)
    
                if self.is_shift_pic_bboxes:
                    if random.random() < self.shift_rate:  # 平移
                        change_num += 1
                        img, bboxes = self._shift_pic_bboxes(img, bboxes)
    
                if self.is_changeLight:
                    if random.random() > self.change_light_rate:  # 改变亮度
                        change_num += 1
                        img = self._changeLight(img)
    
                if self.is_addNoise:
                    if random.random() < self.add_noise_rate:  # 加噪声
                        change_num += 1
                        img = self._addNoise(img)
                if self.is_cutout:
                    if random.random() < self.cutout_rate:  # cutout
                        change_num += 1
                        img = self._cutout(img, bboxes, length=self.cut_out_length, n_holes=self.cut_out_holes,
                                           threshold=self.cut_out_threshold)
                if self.is_filp_pic_bboxes:
                    if random.random() < self.flip_rate:  # 翻转
                        change_num += 1
                        img, bboxes = self._filp_pic_bboxes(img, bboxes)
    
            return img, bboxes
    
    
    # xml解析工具
    class ToolHelper():
        # 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
        def parse_xml(self, path):
            '''
            输入:
                xml_path: xml的文件路径
            输出:
                从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
            '''
            tree = ET.parse(path)
            root = tree.getroot()
            objs = root.findall('object')
            coords = list()
            for ix, obj in enumerate(objs):
                name = obj.find('name').text
                box = obj.find('bndbox')
                x_min = int(box[0].text)
                y_min = int(box[1].text)
                x_max = int(box[2].text)
                y_max = int(box[3].text)
                coords.append([x_min, y_min, x_max, y_max, name])
            return coords
    
        # 保存图片结果
        def save_img(self, file_name, save_folder, img):
            cv2.imwrite(os.path.join(save_folder, file_name), img)
    
        # 保持xml结果
        def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
            '''
            :param file_name:文件名
            :param save_folder:#保存的xml文件的结果
            :param height:图片的信息
            :param width:图片的宽度
            :param channel:通道
            :return:
            '''
            folder_name, img_name = img_info  # 得到图片的信息
    
            E = objectify.ElementMaker(annotate=False)
    
            anno_tree = E.annotation(
                E.folder(folder_name),
                E.filename(img_name),
                E.path(os.path.join(folder_name, img_name)),
                E.source(
                    E.database('Unknown'),
                ),
                E.size(
                    E.width(width),
                    E.height(height),
                    E.depth(channel)
                ),
                E.segmented(0),
            )
    
            labels, bboxs = bboxs_info  # 得到边框和标签信息
            for label, box in zip(labels, bboxs):
                anno_tree.append(
                    E.object(
                        E.name(label),
                        E.pose('Unspecified'),
                        E.truncated('0'),
                        E.difficult('0'),
                        E.bndbox(
                            E.xmin(box[0]),
                            E.ymin(box[1]),
                            E.xmax(box[2]),
                            E.ymax(box[3])
                        )
                    ))
    
            etree.ElementTree(anno_tree).write(os.path.join(save_folder, file_name), pretty_print=True)
    
    
    if __name__ == '__main__':
    
        need_aug_num = 5  # 每张图片需要增强的次数
    
        is_endwidth_dot = True  # 文件是否以.jpg或者png结尾
    
        dataAug = DataAugmentForObjectDetection()  # 数据增强工具类
    
        toolhelper = ToolHelper()  # 工具
    
        # 获取相关参数
        parser = argparse.ArgumentParser()
        parser.add_argument('--source_img_path', type=str, default='images')
        parser.add_argument('--source_xml_path', type=str, default='Annotations')
        parser.add_argument('--save_img_path', type=str, default='enhance_images')
        parser.add_argument('--save_xml_path', type=str, default='enhance_Annotations')
        args = parser.parse_args()
        source_img_path = args.source_img_path  # 图片原始位置
        source_xml_path = args.source_xml_path  # xml的原始位置
    
        save_img_path = args.save_img_path  # 图片增强结果保存文件
        save_xml_path = args.save_xml_path  # xml增强结果保存文件
    
        # 如果保存文件夹不存在就创建
        if not os.path.exists(save_img_path):
            os.mkdir(save_img_path)
    
        if not os.path.exists(save_xml_path):
            os.mkdir(save_xml_path)
    
        for parent, _, files in os.walk(source_img_path):
            files.sort()
            for file in files:
                cnt = 0
                pic_path = os.path.join(parent, file)
                xml_path = os.path.join(source_xml_path, file[:-4] + '.xml')
                values = toolhelper.parse_xml(xml_path)  # 解析得到box信息,格式为[[x_min,y_min,x_max,y_max,name]]
                coords = [v[:4] for v in values]  # 得到框
                labels = [v[-1] for v in values]  # 对象的标签
    
                # 如果图片是有后缀的
                if is_endwidth_dot:
                    # 找到文件的最后名字
                    dot_index = file.rfind('.')
                    _file_prefix = file[:dot_index]  # 文件名的前缀
                    _file_suffix = file[dot_index:]  # 文件名的后缀
                img = cv2.imread(pic_path)
    
                # show_pic(img, coords)  # 显示原图
                while cnt < need_aug_num:  # 继续增强
                    auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
                    auged_bboxes_int = np.array(auged_bboxes).astype(np.int32)
                    height, width, channel = auged_img.shape  # 得到图片的属性
                    img_name = '{}_{}{}'.format(_file_prefix, cnt + 1, _file_suffix)  # 图片保存的信息
                    toolhelper.save_img(img_name, save_img_path,
                                        auged_img)  # 保存增强图片
    
                    toolhelper.save_xml('{}_{}.xml'.format(_file_prefix, cnt + 1),
                                        save_xml_path, (save_img_path, img_name), height, width, channel,
                                        (labels, auged_bboxes_int))  # 保存xml文件
                    # show_pic(auged_img, auged_bboxes)  # 强化后的图
                    print(img_name)
                    cnt += 1  # 继续增强下一张
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
    • 166
    • 167
    • 168
    • 169
    • 170
    • 171
    • 172
    • 173
    • 174
    • 175
    • 176
    • 177
    • 178
    • 179
    • 180
    • 181
    • 182
    • 183
    • 184
    • 185
    • 186
    • 187
    • 188
    • 189
    • 190
    • 191
    • 192
    • 193
    • 194
    • 195
    • 196
    • 197
    • 198
    • 199
    • 200
    • 201
    • 202
    • 203
    • 204
    • 205
    • 206
    • 207
    • 208
    • 209
    • 210
    • 211
    • 212
    • 213
    • 214
    • 215
    • 216
    • 217
    • 218
    • 219
    • 220
    • 221
    • 222
    • 223
    • 224
    • 225
    • 226
    • 227
    • 228
    • 229
    • 230
    • 231
    • 232
    • 233
    • 234
    • 235
    • 236
    • 237
    • 238
    • 239
    • 240
    • 241
    • 242
    • 243
    • 244
    • 245
    • 246
    • 247
    • 248
    • 249
    • 250
    • 251
    • 252
    • 253
    • 254
    • 255
    • 256
    • 257
    • 258
    • 259
    • 260
    • 261
    • 262
    • 263
    • 264
    • 265
    • 266
    • 267
    • 268
    • 269
    • 270
    • 271
    • 272
    • 273
    • 274
    • 275
    • 276
    • 277
    • 278
    • 279
    • 280
    • 281
    • 282
    • 283
    • 284
    • 285
    • 286
    • 287
    • 288
    • 289
    • 290
    • 291
    • 292
    • 293
    • 294
    • 295
    • 296
    • 297
    • 298
    • 299
    • 300
    • 301
    • 302
    • 303
    • 304
    • 305
    • 306
    • 307
    • 308
    • 309
    • 310
    • 311
    • 312
    • 313
    • 314
    • 315
    • 316
    • 317
    • 318
    • 319
    • 320
    • 321
    • 322
    • 323
    • 324
    • 325
    • 326
    • 327
    • 328
    • 329
    • 330
    • 331
    • 332
    • 333
    • 334
    • 335
    • 336
    • 337
    • 338
    • 339
    • 340
    • 341
    • 342
    • 343
    • 344
    • 345
    • 346
    • 347
    • 348
    • 349
    • 350
    • 351
    • 352
    • 353
    • 354
    • 355
    • 356
    • 357
    • 358
    • 359
    • 360
    • 361
    • 362
    • 363
    • 364
    • 365
    • 366
    • 367
    • 368
    • 369
    • 370
    • 371
    • 372
    • 373
    • 374
    • 375
    • 376
    • 377
    • 378
    • 379
    • 380
    • 381
    • 382
    • 383
    • 384
    • 385
    • 386
    • 387
    • 388
    • 389
    • 390
    • 391
    • 392
    • 393
    • 394
    • 395
    • 396
    • 397
    • 398
    • 399
    • 400
    • 401
    • 402
    • 403
    • 404
    • 405
    • 406
    • 407
    • 408
    • 409
    • 410
    • 411
    • 412
    • 413
    • 414
    • 415
    • 416
    • 417
    • 418
    • 419
    • 420
    • 421
    • 422
    • 423
    • 424
    • 425
    • 426
    • 427
    • 428
    • 429
    • 430
    • 431
    • 432
    • 433
    • 434
    • 435
    • 436
    • 437
    • 438
    • 439
    • 440
    • 441
    • 442
    • 443
    • 444
    • 445
    • 446
    • 447
    • 448
    • 449
    • 450
    • 451
    • 452
    • 453
    • 454
    • 455
    • 456
    • 457
    • 458
    • 459
    • 460
    • 461
    • 462
    • 463
    • 464
    • 465
    • 466
    • 467
    • 468
    • 469
    • 470
    • 471
    • 472
    • 473
    • 474
    • 475
    • 476
    • 477
    • 478
    • 479
    • 480
    • 481
    • 482
    • 483
    • 484
    • 485
    • 486
    • 487
    • 488
    • 489
    • 490
    • 491
    • 492
    • 493
    • 494
    • 495
    • 496
    • 497
    • 498
    • 499
    • 500
    • 501
    • 502
    • 503
    • 504
    • 505
    • 506
    • 507
    • 508
    • 509
    • 510
    • 511
    • 512
    • 513
    • 514
    • 515
    • 516
    • 517
    • 518
    • 519
    • 520
    • 521
    • 522
    • 523
    • 524
    • 525
    • 526
    • 527
    • 528
    • 529
    • 530
    • 531
    • 532
    • 533
    • 534
    • 535
    • 536
    • 537
    • 538
    • 539
    • 540
    • 541
    • 542
    • 543
    • 544

    增强后的效果图如下所示:
    在这里插入图片描述
    详细实战使用操作请看:基于yolov8的车牌检测训练全流程

  • 相关阅读:
    常见web安全及防护原理
    JS基础之实现数组reduce方法
    db_ha执行ha_isready报错authentication method 13 not supported
    理论和实践详解RabbitMQ死信(dead lettering)(带测试样例和分析)
    three.js 基础01
    什么是私有云?您应该知道的 6 个优势
    python 对象列表list转字典dict
    Clion新增子模块库代码跳转
    Bigkey问题的解决思路与方式探索
    C++08函数模板
  • 原文地址:https://blog.csdn.net/m0_63007797/article/details/134517280