• [Object Detection] YOLO + DOTA: A Small-Object Detection Strategy


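    Preface

    When I previously ran YOLOv5 on the xView dataset, the accuracy turned out to be very low. While browsing online, I came across a strategy for detecting small objects: split each high-resolution image into smaller tiles for training, then feed the large image in for detection. This post uses the DOTA dataset to check whether the idea works.

    Main reference project: https://github.com/postor/DOTA-yolov3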

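    Introduction to the DOTA Dataset

    DOTA stands for "Dataset for Object deTection in Aerial images".
    DOTA v1.0 contains 2806 images, ranging in size from about 800 × 800 to about 4000 × 4000 pixels, with 188282 annotated objects in total.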

    DOTA paper: https://arxiv.org/pdf/1711.10398.pdf
    Dataset website: https://captain-whu.github.io/DOTA/dataset.html

    The DOTA dataset has three versions:

    DOTA v1.0

    • Number of classes: 15
    • Class names: plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, swimming pool

    DOTA v1.5

    • Number of classes: 16
    • Class names: plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, swimming pool, container crane

    DOTA v2.0

    • Number of classes: 18
    • Class names: plane, ship, storage tank, baseball diamond, tennis court, basketball court, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, swimming pool, container crane, airport, helipad

    This experiment uses DOTA v2.0, which I have also backed up on my GitHub:
    https://github.com/zstar1003/Dataset

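    Image Splitting

    Image splitting cuts each large image into small tiles; the labels have to be converted into each tile's coordinates at the same time.
    To keep objects from being cut apart at tile borders, every two adjacent tiles share an overlapping region. The diagram I drew below illustrates the splitting strategy.

    [Figure: diagram of the splitting strategy with overlapping tiles]
    The splitting code is the split.py program provided by the reference project.
    The following parameters need to be specified:

    • input folder path (the code expects `images` and `labelTxt` subfolders inside it)
    • output folder path (the same two subfolders are created inside it)
    • gap: width of the overlap between two adjacent tiles
    • subsize: tile size
    • num_process: number of worker processes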
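
    To make the roles of gap and subsize concrete, here is a minimal sketch of how the tile positions along one image axis can be derived. It mirrors the sliding-window loop in SplitSingle further down, but the function names `tile_starts` and `tile_origins` are my own and do not exist in the reference project.

```python
def tile_starts(length, subsize=1024, gap=200):
    """Return the start offsets of the tiles along one image axis.

    Hypothetical helper that mirrors the while-loops in SplitSingle.
    """
    slide = subsize - gap  # stride between two neighbouring tiles
    if slide <= 0:
        raise ValueError("gap must be smaller than subsize")
    starts = []
    pos = 0
    while True:
        if pos + subsize >= length:
            # Last tile: shift it back so that it ends at the image border.
            starts.append(max(length - subsize, 0))
            break
        starts.append(pos)
        pos += slide
    return starts


def tile_origins(width, height, subsize=1024, gap=200):
    """Return the (left, up) corner of every tile of a width x height image."""
    return [(left, up)
            for left in tile_starts(width, subsize, gap)
            for up in tile_starts(height, subsize, gap)]


if __name__ == "__main__":
    print(tile_starts(4000))             # start offsets along a 4000-pixel axis
    print(len(tile_origins(4000, 4000))) # number of tiles for a 4000 x 4000 image
```

    With subsize=1024 and gap=200 the stride is 824 pixels, so a 4000-pixel axis yields five start offsets, and the last tile is shifted back to end at the image border rather than being padded.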

    Full code:

    import os
    import codecs
    import numpy as np
    import math
    from dota_utils import GetFileFromThisRootDir
    import cv2
    import shapely.geometry as shgeo
    import dota_utils as util
    import copy
    from multiprocessing import Pool
    from functools import partial
    import time
    
    
    def choose_best_pointorder_fit_another(poly1, poly2):
        """
            To make the two polygons best fit with each point
        """
        x1 = poly1[0]
        y1 = poly1[1]
        x2 = poly1[2]
        y2 = poly1[3]
        x3 = poly1[4]
        y3 = poly1[5]
        x4 = poly1[6]
        y4 = poly1[7]
        combinate = [np.array([x1, y1, x2, y2, x3, y3, x4, y4]), np.array([x2, y2, x3, y3, x4, y4, x1, y1]),
                     np.array([x3, y3, x4, y4, x1, y1, x2, y2]), np.array([x4, y4, x1, y1, x2, y2, x3, y3])]
        dst_coordinate = np.array(poly2)
        distances = np.array([np.sum((coord - dst_coordinate) ** 2) for coord in combinate])
        sorted = distances.argsort()
        return combinate[sorted[0]]
    
    
    def cal_line_length(point1, point2):
        return math.sqrt(math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2))
    
    
    def split_single_warp(name, split_base, rate, extent):
        split_base.SplitSingle(name, rate, extent)
    
    
    class splitbase():
        def __init__(self,
                     basepath,
                     outpath,
                     code='utf-8',
                     gap=512,
                     subsize=1024,
                     thresh=0.7,
                     choosebestpoint=True,
                     ext='.png',
                     padding=True,
                     num_process=8
                     ):
            """
            :param basepath: base path for dota data
            :param outpath: output base path for dota data,
            the basepath and outputpath have the similar subdirectory, 'images' and 'labelTxt'
            :param code: encodeing format of txt file
            :param gap: overlap between two patches
            :param subsize: subsize of patch
            :param thresh: the thresh determine whether to keep the instance if the instance is cut down in the process of split
            :param choosebestpoint: used to choose the first point for the
            :param ext: ext for the image format
            :param padding: if to padding the images so that all the images have the same size
            """
            self.basepath = basepath
            self.outpath = outpath
            self.code = code
            self.gap = gap
            self.subsize = subsize
            self.slide = self.subsize - self.gap
            self.thresh = thresh
            self.imagepath = os.path.join(self.basepath, 'images')
            self.labelpath = os.path.join(self.basepath, 'labelTxt')
            self.outimagepath = os.path.join(self.outpath, 'images')
            self.outlabelpath = os.path.join(self.outpath, 'labelTxt')
            self.choosebestpoint = choosebestpoint
            self.ext = ext
            self.padding = padding
            self.num_process = num_process
            self.pool = Pool(num_process)
            print('padding:', padding)
    
            # pdb.set_trace()
            if not os.path.isdir(self.outpath):
                os.mkdir(self.outpath)
            if not os.path.isdir(self.outimagepath):
                # pdb.set_trace()
                os.mkdir(self.outimagepath)
            if not os.path.isdir(self.outlabelpath):
                os.mkdir(self.outlabelpath)
            # pdb.set_trace()
    
        ## point: (x, y), rec: (xmin, ymin, xmax, ymax)
        # def __del__(self):
        #     self.f_sub.close()
        ## grid --> (x, y) position of grids
        def polyorig2sub(self, left, up, poly):
            polyInsub = np.zeros(len(poly))
            for i in range(int(len(poly) / 2)):
                polyInsub[i * 2] = int(poly[i * 2] - left)
                polyInsub[i * 2 + 1] = int(poly[i * 2 + 1] - up)
            return polyInsub
    
        def calchalf_iou(self, poly1, poly2):
            """
                It is not the iou on usual, the iou is the value of intersection over poly1
            """
            inter_poly = poly1.intersection(poly2)
            inter_area = inter_poly.area
            poly1_area = poly1.area
            half_iou = inter_area / poly1_area
            return inter_poly, half_iou
    
        def saveimagepatches(self, img, subimgname, left, up):
            subimg = copy.deepcopy(img[up: (up + self.subsize), left: (left + self.subsize)])
            outdir = os.path.join(self.outimagepath, subimgname + self.ext)
            h, w, c = np.shape(subimg)
            if (self.padding):
                outimg = np.zeros((self.subsize, self.subsize, 3))
                outimg[0:h, 0:w, :] = subimg
                cv2.imwrite(outdir, outimg)
            else:
                cv2.imwrite(outdir, subimg)
    
        def GetPoly4FromPoly5(self, poly):
            distances = [cal_line_length((poly[i * 2], poly[i * 2 + 1]), (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1])) for i
                         in range(int(len(poly) / 2 - 1))]
            distances.append(cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
            pos = np.array(distances).argsort()[0]
            count = 0
            outpoly = []
            while count < 5:
                # print('count:', count)
                if (count == pos):
                    outpoly.append((poly[count * 2] + poly[(count * 2 + 2) % 10]) / 2)
                    outpoly.append((poly[(count * 2 + 1) % 10] + poly[(count * 2 + 3) % 10]) / 2)
                    count = count + 1
                elif (count == (pos + 1) % 5):
                    count = count + 1
                    continue
    
                else:
                    outpoly.append(poly[count * 2])
                    outpoly.append(poly[count * 2 + 1])
                    count = count + 1
            return outpoly
    
        def savepatches(self, resizeimg, objects, subimgname, left, up, right, down):
            outdir = os.path.join(self.outlabelpath, subimgname + '.txt')
            mask_poly = []
            imgpoly = shgeo.Polygon([(left, up), (right, up), (right, down),
                                     (left, down)])
            with codecs.open(outdir, 'w', self.code) as f_out:
                for obj in objects:
                    gtpoly = shgeo.Polygon([(obj['poly'][0], obj['poly'][1]),
                                            (obj['poly'][2], obj['poly'][3]),
                                            (obj['poly'][4], obj['poly'][5]),
                                            (obj['poly'][6], obj['poly'][7])])
                    if (gtpoly.area <= 0):
                        continue
                    inter_poly, half_iou = self.calchalf_iou(gtpoly, imgpoly)
    
                    # print('writing...')
                    if (half_iou == 1):
                        polyInsub = self.polyorig2sub(left, up, obj['poly'])
                        outline = ' '.join(list(map(str, polyInsub)))
                        outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                        f_out.write(outline + '\n')
                    elif (half_iou > 0):
                        # elif (half_iou > self.thresh):
                        ##  print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
                        inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
                        out_poly = list(inter_poly.exterior.coords)[0: -1]
                        if len(out_poly) < 4:
                            continue
    
                        out_poly2 = []
                        for i in range(len(out_poly)):
                            out_poly2.append(out_poly[i][0])
                            out_poly2.append(out_poly[i][1])
    
                        if (len(out_poly) == 5):
                            # print('==========================')
                            out_poly2 = self.GetPoly4FromPoly5(out_poly2)
                        elif (len(out_poly) > 5):
                            """
                                if the cut instance is a polygon with points more than 5, we do not handle it currently
                            """
                            continue
                        if (self.choosebestpoint):
                            out_poly2 = choose_best_pointorder_fit_another(out_poly2, obj['poly'])
    
                        polyInsub = self.polyorig2sub(left, up, out_poly2)
    
                        for index, item in enumerate(polyInsub):
                            if (item <= 1):
                                polyInsub[index] = 1
                            elif (item >= self.subsize):
                                polyInsub[index] = self.subsize
                        outline = ' '.join(list(map(str, polyInsub)))
                        if (half_iou > self.thresh):
                            outline = outline + ' ' + obj['name'] + ' ' + str(obj['difficult'])
                        else:
                            ## if the left part is too small, label as '2'
                            outline = outline + ' ' + obj['name'] + ' ' + '2'
                        f_out.write(outline + '\n')
                    # else:
                    #   mask_poly.append(inter_poly)
            self.saveimagepatches(resizeimg, subimgname, left, up)
    
        def SplitSingle(self, name, rate, extent):
            """
                split a single image and ground truth
            :param name: image name
            :param rate: the resize scale for the image
            :param extent: the image format
            :return:
            """
            img = cv2.imread(os.path.join(self.imagepath, name + extent))
            if np.shape(img) == ():
                return
            fullname = os.path.join(self.labelpath, name + '.txt')
            objects = util.parse_dota_poly2(fullname)
            for obj in objects:
                obj['poly'] = list(map(lambda x: rate * x, obj['poly']))
                # obj['poly'] = list(map(lambda x: ([2 * y for y in x]), obj['poly']))
    
            if (rate != 1):
                resizeimg = cv2.resize(img, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
            else:
                resizeimg = img
            outbasename = name + '__' + str(rate) + '__'
            weight = np.shape(resizeimg)[1]
            height = np.shape(resizeimg)[0]
    
            left, up = 0, 0
            while (left < weight):
                if (left + self.subsize >= weight):
                    left = max(weight - self.subsize, 0)
                up = 0
                while (up < height):
                    if (up + self.subsize >= height):
                        up = max(height - self.subsize, 0)
                    right = min(left + self.subsize, weight - 1)
                    down = min(up + self.subsize, height - 1)
                    subimgname = outbasename + str(left) + '___' + str(up)
                    # self.f_sub.write(name + ' ' + subimgname + ' ' + str(left) + ' ' + str(up) + '\n')
                    self.savepatches(resizeimg, objects, subimgname, left, up, right, down)
                    if (up + self.subsize >= height):
                        break
                    else:
                        up = up + self.slide
                if (left + self.subsize >= weight):
                    break
                else:
                    left = left + self.slide
    
        def splitdata(self, rate):
            """
            :param rate: resize rate before cut
            """
            imagelist = GetFileFromThisRootDir(self.imagepath)
            imagenames = [util.custombasename(x) for x in imagelist if (util.custombasename(x) != 'Thumbs')]
            if self.num_process == 1:
                for name in imagenames:
                    self.SplitSingle(name, rate, self.ext)
            else:
    
                # worker = partial(self.SplitSingle, rate=rate, extent=self.ext)
                worker = partial(split_single_warp, split_base=self, rate=rate, extent=self.ext)
                self.pool.map(worker, imagenames)
    
        def __getstate__(self):
            self_dict = self.__dict__.copy()
            del self_dict['pool']
            return self_dict
    
        def __setstate__(self, state):
            self.__dict__.update(state)
    
    
    if __name__ == '__main__':
        split = splitbase('D:/Dataset/DOTA-v2.0/train',
                          'D:/Dataset/DOTA-v2.0/trainsplit',
                          gap=200,
                          subsize=1024,
                          num_process=8
                          )
        split.splitdata(1)
    
        split = splitbase('D:/Dataset/DOTA-v2.0/val',
                          'D:/Dataset/DOTA-v2.0/valsplit',
                          gap=200,
                          subsize=1024,
                          num_process=8
                          )
        split.splitdata(1)
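
    SplitSingle names each tile `name__rate__left___up`, so the tile's position in the source image can be recovered from the file name alone. The sketch below shows one way to parse that name and map a point from tile coordinates back to the original image; `parse_tile_name` and `tile_to_image_xy` are hypothetical helpers, not part of the reference project.

```python
import re

# Pattern for names produced by SplitSingle: <name>__<rate>__<left>___<up>
_TILE_NAME = re.compile(
    r'^(?P<name>.+)__(?P<rate>[\d.]+)__(?P<left>\d+)___(?P<up>\d+)$')


def parse_tile_name(tile_name):
    """Split a tile name into (image name, resize rate, left, up)."""
    m = _TILE_NAME.match(tile_name)
    if m is None:
        raise ValueError("not a tile name: %r" % tile_name)
    return m['name'], float(m['rate']), int(m['left']), int(m['up'])


def tile_to_image_xy(x, y, tile_name):
    """Map a point from tile coordinates back to original-image coordinates.

    The split first scales coordinates by `rate` and then subtracts the
    tile offset, so the inverse adds the offset and divides by `rate`.
    """
    _, rate, left, up = parse_tile_name(tile_name)
    return (x + left) / rate, (y + up) / rate


if __name__ == "__main__":
    print(parse_tile_name('P0001__1__824___1648'))
    print(tile_to_image_xy(10, 20, 'P0001__1__824___1648'))
```

    This is only the coordinate bookkeeping; the image name `P0001` is used purely as an illustration.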
    

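    Label Conversion

    DOTA labels are not in the format YOLO requires, so they have to be converted. As the figure below shows, the original labels on the left must be converted into the YOLO format on the right.

    [Figure: original DOTA labels (left) and converted YOLO-format labels (right)]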
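
    As a rough illustration of the conversion, the sketch below turns one DOTA label line (eight polygon coordinates, class name, difficulty flag) into a YOLO tuple of class index plus normalized center and size. It follows the same arithmetic as dots4ToRecC in dota_utils.py, but `dota_line_to_yolo` itself is a hypothetical helper, and the sample line and class list are made up for the example.

```python
def dota_line_to_yolo(line, img_w, img_h, class_names):
    """Convert one DOTA label line to (class_id, cx, cy, w, h), or None.

    The oriented polygon is reduced to its axis-aligned bounding box and
    normalized by the image size.
    """
    parts = line.strip().split()
    if len(parts) < 9:
        return None  # header or malformed line
    coords = [float(v) for v in parts[:8]]
    name = parts[8]
    if name not in class_names:
        return None  # class not selected
    xs, ys = coords[0::2], coords[1::2]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    box = ((xmin + xmax) / 2 / img_w,
           (ymin + ymax) / 2 / img_h,
           (xmax - xmin) / img_w,
           (ymax - ymin) / img_h)
    # Like the original script, drop boxes with any value outside (0, 1).
    if any(v <= 0 or v >= 1 for v in box):
        return None
    return (class_names.index(name),) + box


if __name__ == "__main__":
    classes = ['airport', 'plane']  # illustrative subset of the class list
    sample = '100 200 300 200 300 400 100 400 plane 0'
    print(dota_line_to_yolo(sample, 1000, 1000, classes))
```

    Note that the rotation of the original polygon is discarded: only the enclosing horizontal box survives the conversion.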
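    This step uses the YOLO_Transform.py script from the reference project. Note that the class-name list wordname_18 in dota_utils.py has to be edited to match the classes in use.
    YOLO_Transform.py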

    import dota_utils as util
    import os
    import numpy as np
    from PIL import Image
    
    from PIL import ImageFile
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    Image.MAX_IMAGE_PIXELS = None
    
    
    ## trans dota format to format YOLO(darknet) required
    def dota2darknet(imgpath, txtpath, dstpath, extractclassname):
        """
        :param imgpath: the path of images
        :param txtpath: the path of txt in dota format
        :param dstpath: the path of txt in YOLO format
        :param extractclassname: the category you selected
        :return:
        """
        filelist = util.GetFileFromThisRootDir(txtpath)
        for fullname in filelist:
            objects = util.parse_dota_poly(fullname)
            name = os.path.splitext(os.path.basename(fullname))[0]
            img_fullname = os.path.join(imgpath, name + '.png')
            img = Image.open(img_fullname)
            img_w, img_h = img.size
            # print img_w,img_h
            with open(os.path.join(dstpath, name + '.txt'), 'w') as f_out:
                for obj in objects:
                    poly = obj['poly']
                    bbox = np.array(util.dots4ToRecC(poly, img_w, img_h))
                    if (sum(bbox <= 0) + sum(bbox >= 1)) >= 1:
                        continue
                    if (obj['name'] in extractclassname):
                        id = extractclassname.index(obj['name'])
                    else:
                        continue
                    outline = str(id) + ' ' + ' '.join(list(map(str, bbox)))
                    f_out.write(outline + '\n')
    
    
    if __name__ == '__main__':
        dota2darknet('C:/Users/xy/Desktop/Work/upload/DOTA/train/images',
                     'C:/Users/xy/Desktop/Work/upload/DOTA/train/labels1',
                     'C:/Users/xy/Desktop/Work/upload/DOTA/train/labels',
                     util.wordname_18)
    
        dota2darknet('C:/Users/xy/Desktop/Work/upload/DOTA/val/images',
                     'C:/Users/xy/Desktop/Work/upload/DOTA/val/labels1',
                     'C:/Users/xy/Desktop/Work/upload/DOTA/val/labels',
                     util.wordname_18)
    

    dota_utils.py

    import sys
    import codecs
    import numpy as np
    import shapely.geometry as shgeo
    import os
    import re
    import math
    """
        some basic functions which are useful for process DOTA data
    """
    
    wordname_18 = [
      'airport',
      'small-vehicle',
      'large-vehicle',
      'plane',
      'storage-tank',
      'ship',
      'harbor',
      'ground-track-field',
      'soccer-ball-field',
      'tennis-court',
      'swimming-pool',
      'baseball-diamond',
      'roundabout',
      'basketball-court',
      'bridge',
      'helicopter',
      'container-crane',
      'helipad']
    
    
    def custombasename(fullname):
        return os.path.basename(os.path.splitext(fullname)[0])
    
    def GetFileFromThisRootDir(dir,ext = None):
      allfiles = []
      needExtFilter = (ext != None)
      for root,dirs,files in os.walk(dir):
        for filespath in files:
          filepath = os.path.join(root, filespath)
          extension = os.path.splitext(filepath)[1][1:]
          if needExtFilter and extension in ext:
            allfiles.append(filepath)
          elif not needExtFilter:
            allfiles.append(filepath)
      return allfiles
    
    def TuplePoly2Poly(poly):
        outpoly = [poly[0][0], poly[0][1],
                           poly[1][0], poly[1][1],
                           poly[2][0], poly[2][1],
                           poly[3][0], poly[3][1]
                           ]
        return outpoly
    
    def parse_dota_poly(filename):
        """
            parse the dota ground truth in the format:
            [(x1, y1), (x2, y2), (x3, y3), (x4, y4)]
        """
        objects = []
        #print('filename:', filename)
        f = []
        if (sys.version_info >= (3, 5)):
            fd = open(filename, 'r')
            f = fd
        elif (sys.version_info >= (2, 7)):
            fd = codecs.open(filename, 'r')
            f = fd
        # count = 0
        while True:
            line = f.readline()
            # count = count + 1
            # if count < 2:
            #     continue
            if line:
                splitlines = line.strip().split(' ')
                object_struct = {}
                ### clear the wrong name after check all the data
                #if (len(splitlines) >= 9) and (splitlines[8] in classname):
                if (len(splitlines) < 9):
                    continue
                if (len(splitlines) >= 9):
                        object_struct['name'] = splitlines[8]
                if (len(splitlines) == 9):
                    object_struct['difficult'] = '0'
                elif (len(splitlines) >= 10):
                    # if splitlines[9] == '1':
                    # if (splitlines[9] == 'tr'):
                    #     object_struct['difficult'] = '1'
                    # else:
                    object_struct['difficult'] = splitlines[9]
                    # else:
                    #     object_struct['difficult'] = 0
                object_struct['poly'] = [(float(splitlines[0]), float(splitlines[1])),
                                         (float(splitlines[2]), float(splitlines[3])),
                                         (float(splitlines[4]), float(splitlines[5])),
                                         (float(splitlines[6]), float(splitlines[7]))
                                         ]
                gtpoly = shgeo.Polygon(object_struct['poly'])
                object_struct['area'] = gtpoly.area
                # poly = list(map(lambda x:np.array(x), object_struct['poly']))
                # object_struct['long-axis'] = max(distance(poly[0], poly[1]), distance(poly[1], poly[2]))
                # object_struct['short-axis'] = min(distance(poly[0], poly[1]), distance(poly[1], poly[2]))
                # if (object_struct['long-axis'] < 15):
                #     object_struct['difficult'] = '1'
                #     global small_count
                #     small_count = small_count + 1
                objects.append(object_struct)
            else:
                break
        return objects
    
    def dots4ToRecC(poly, img_w, img_h):
        xmin, ymin, xmax, ymax = dots4ToRec4(poly)
        x = (xmin + xmax)/2
        y = (ymin + ymax)/2
        w = xmax - xmin
        h = ymax - ymin
        return x/img_w, y/img_h, w/img_w, h/img_h
    
    def parse_dota_poly2(filename):
        """
            parse the dota ground truth in the format:
            [x1, y1, x2, y2, x3, y3, x4, y4]
        """
        objects = parse_dota_poly(filename)
        for obj in objects:
            obj['poly'] = TuplePoly2Poly(obj['poly'])
            obj['poly'] = list(map(int, obj['poly']))
        return objects
    
    def parse_dota_rec(filename):
        """
            parse the dota ground truth in the bounding box format:
            "xmin, ymin, xmax, ymax"
        """
        objects = parse_dota_poly(filename)
        for obj in objects:
            poly = obj['poly']
            bbox = dots4ToRec4(poly)
            obj['bndbox'] = bbox
        return objects
    ## bounding box transfer for varies format
    
    def dots4ToRec4(poly):
        xmin, xmax, ymin, ymax = min(poly[0][0], min(poly[1][0], min(poly[2][0], poly[3][0]))), \
                                max(poly[0][0], max(poly[1][0], max(poly[2][0], poly[3][0]))), \
                                 min(poly[0][1], min(poly[1][1], min(poly[2][1], poly[3][1]))), \
                                 max(poly[0][1], max(poly[1][1], max(poly[2][1], poly[3][1])))
        return xmin, ymin, xmax, ymax
    def dots4ToRec8(poly):
        xmin, ymin, xmax, ymax = dots4ToRec4(poly)
        return xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax
        #return dots2ToRec8(dots4ToRec4(poly))
    def dots2ToRec8(rec):
        xmin, ymin, xmax, ymax = rec[0], rec[1], rec[2], rec[3]
        return xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax
    
    def groundtruth2Task1(srcpath, dstpath):
        filelist = GetFileFromThisRootDir(srcpath)
        # names = [custombasename(x.strip())for x in filelist]
        filedict = {}
        for cls in wordname_18:
            fd = open(os.path.join(dstpath, 'Task1_') + cls + r'.txt', 'w')
            filedict[cls] = fd
        for filepath in filelist:
            objects = parse_dota_poly2(filepath)
    
            subname = custombasename(filepath)
            pattern2 = re.compile(r'__([\d+\.]+)__\d+___')
            rate = re.findall(pattern2, subname)[0]
    
            for obj in objects:
                category = obj['name']
                difficult = obj['difficult']
                poly = obj['poly']
                if difficult == '2':
                    continue
                if rate == '0.5':
                    outline = custombasename(filepath) + ' ' + '1' + ' ' + ' '.join(map(str, poly))
                elif rate == '1':
                    outline = custombasename(filepath) + ' ' + '0.8' + ' ' + ' '.join(map(str, poly))
                elif rate == '2':
                    outline = custombasename(filepath) + ' ' + '0.6' + ' ' + ' '.join(map(str, poly))
    
                filedict[category].write(outline + '\n')
    
    def Task2groundtruth_poly(srcpath, dstpath):
        thresh = 0.1
        filedict = {}
        Tasklist = GetFileFromThisRootDir(srcpath, '.txt')
    
        for Taskfile in Tasklist:
            idname = custombasename(Taskfile).split('_')[-1]
            # idname = datamap_inverse[idname]
            f = open(Taskfile, 'r')
            lines = f.readlines()
            for line in lines:
                if len(line.strip()) == 0:
                    continue
                # print('line:', line)
                splitline = line.strip().split(' ')
                filename = splitline[0]
                confidence = splitline[1]
                bbox = splitline[2:]
                if float(confidence) > thresh:
                    if filename not in filedict:
                        filedict[filename] = codecs.open(os.path.join(dstpath, filename + '.txt'), 'w')
                    poly = bbox
                    # write only detections above the confidence threshold;
                    # writing outside this branch raises KeyError for unseen files
                    filedict[filename].write(' '.join(poly) + ' ' + idname + '\n')
    
    
    def polygonToRotRectangle(bbox):
        """
        :param bbox: The polygon stored in format [x1, y1, x2, y2, x3, y3, x4, y4]
        :return: Rotated Rectangle in format [cx, cy, w, h, theta]
        """
        bbox = np.array(bbox,dtype=np.float32)
        bbox = np.reshape(bbox,newshape=(2,4),order='F')
        angle = math.atan2(-(bbox[0,1]-bbox[0,0]),bbox[1,1]-bbox[1,0])
    
        center = [[0],[0]]
    
        for i in range(4):
            center[0] += bbox[0,i]
            center[1] += bbox[1,i]
    
        center = np.array(center,dtype=np.float32)/4.0
    
        R = np.array([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]], dtype=np.float32)
    
        normalized = np.matmul(R.transpose(),bbox-center)
    
        xmin = np.min(normalized[0,:])
        xmax = np.max(normalized[0,:])
        ymin = np.min(normalized[1,:])
        ymax = np.max(normalized[1,:])
    
        w = xmax - xmin + 1
        h = ymax - ymin + 1
    
        return [float(center[0]),float(center[1]),w,h,angle]
    

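    Image Format Conversion

    The author of the reference project also provides a script for converting image formats, for example from PNG to JPG, implemented with OpenCV.
    It is not used here, but I include it in case you need it to train on your own dataset.
    imagetrans.py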

    import dota_utils as util
    import cv2
    import os
    
    # this code is used to convert image formats
    # from PNG to JPG
    def imageformatTrans(srcpath, dstpath, format):
        filelist = util.GetFileFromThisRootDir(srcpath)
        for fullname in filelist:
            img = cv2.imread(fullname)
            basename = util.custombasename(fullname)
            dstname = os.path.join(dstpath, basename + format)
            cv2.imwrite(dstname, img)
    
    if __name__ == '__main__':
        # an example
        imageformatTrans('path1', 'path2',
                         '.jpg')
    

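    If you download the dataset I provide, none of these steps are needed: I have already converted both the original dataset with its labels and the split dataset with its labels, so they can be loaded directly into YOLOv5.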

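    Training Results

    The figures below show my training results with the YOLOv5l model. After 100 epochs of training, the model has essentially converged.

    Training results without splitting:
    [Figure: training curves without splitting]

    Training results after splitting:

    [Figure: training curves after splitting]

    Comparison:

    Model                        mAP@.5    mAP@.5:.95
    YOLOv5l (without splitting)  13.5%     5.52%
    YOLOv5l (after splitting)    33.5%     18.6%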

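    Visualized Results

    Two images from DOTA-test are used here for a side-by-side comparison.

    Before splitting:
    [Figure: detections on test image 1, model trained without splitting]

    [Figure: detections on test image 2, model trained without splitting]
    After splitting:

    [Figure: detections on test image 1, model trained on split tiles]
    [Figure: detections on test image 2, model trained on split tiles]

    The difference is quite obvious: after splitting, a small number of objects are still missed, but most are detected accurately.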

  • Original post: https://blog.csdn.net/qq1198768105/article/details/126299646