• YOLOv7中的数据集处理【代码分析】

    本文主要对yolov7中数据集处理部分的代码进行解析(这部分与yolov5相同),以便更好地理解训练时送入网络的数据到底是什么样子。数据集的处理离不开两个类:一个是Dataset(from torch.utils.data import Dataset),一个是DataLoader(from torch.utils.data.dataloader import DataLoader)。通常的做法是继承Dataset来实现自己的数据集类,再用DataLoader(或类似的加载器,例如yolov7中的InfiniteDataLoader)按batch加载数据(在我另外的文章中有讲这两个类的使用)。


    # Trainloader: build the training dataloader together with its dataset
    dataloader, dataset = create_dataloader(
        train_path, imgsz, batch_size, gs, opt,
        hyp=hyp,
        augment=True,
        cache=opt.cache_images,
        rect=opt.rect,
        rank=rank,
        world_size=opt.world_size,
        workers=opt.workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr('train:'),
    )




    def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
                          rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
        """Build a LoadImagesAndLabels dataset and a data loader over it.

        Args:
            path: dataset source forwarded to LoadImagesAndLabels (a txt file listing image paths,
                an image directory, or a list of these).
            imgsz: target image size.
            batch_size: requested batch size; capped at the dataset length.
            stride: model stride, forwarded to the dataset as an int.
            opt: options object; only ``opt.single_cls`` is read here.
            hyp: augmentation hyperparameters.
            augment: enable image augmentation.
            cache: forwarded as ``cache_images`` (falsy = no caching, 'disk' = cache as .npy files,
                any other truthy value = cache in memory).
            pad, rect, image_weights, prefix: forwarded to the dataset.
            rank: process rank; -1 means non-distributed (no DistributedSampler).
            world_size: number of processes sharing the machine's CPUs.
            workers: upper bound on the number of loader worker processes.
            quad: use ``collate_fn4`` instead of ``collate_fn``.

        Returns:
            (dataloader, dataset)
        """
        # Make sure only the first process in DDP processes the dataset first; the others wait and
        # can then reuse the cache it produced.
        with torch_distributed_zero_first(rank):
            dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                          augment=augment,  # augment images
                                          hyp=hyp,  # augmentation hyperparameters
                                          rect=rect,  # rectangular training
                                          cache_images=cache,
                                          single_cls=opt.single_cls,
                                          stride=int(stride),
                                          pad=pad,
                                          image_weights=image_weights,
                                          prefix=prefix)

        batch_size = min(batch_size, len(dataset))
        # os.cpu_count() returns None when the count cannot be determined; fall back to 1 instead of
        # failing with a TypeError on None // world_size.
        cpu_count = os.cpu_count() or 1
        nw = min([cpu_count // world_size, batch_size if batch_size > 1 else 0, workers])  # number of workers
        sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
        # Use torch.utils.data.DataLoader() if dataset properties will update during training
        # (image_weights), else InfiniteDataLoader().
        loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
        dataloader = loader(dataset,
                            batch_size=batch_size,
                            num_workers=nw,
                            sampler=sampler,
                            pin_memory=True,
                            collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
        return dataloader, dataset




    from contextlib import contextmanager


    @contextmanager  # required: without it the generator below cannot be used in a `with` statement
    def torch_distributed_zero_first(local_rank: int):
        """Context manager that lets the local master process run the `with` body first.

        Used by create_dataloader. A process whose ``local_rank`` is neither -1 nor 0 blocks on
        ``torch.distributed.barrier()`` before entering the body, i.e. it waits until every process,
        including the master, has reached a barrier. The master (``local_rank == 0``) enters the body
        immediately (e.g. reads and processes the dataset) and calls ``barrier()`` once it is done;
        at that point all processes have reached the barrier, so they are released together and the
        waiting ones run the body afterwards. With ``local_rank == -1`` (not distributed) no barrier
        is used at all.

        Args:
            local_rank: rank of this process on the local machine, or -1 when not distributed.
        """
        if local_rank not in [-1, 0]:
            torch.distributed.barrier()
        yield
        if local_rank == 0:
            torch.distributed.barrier()


    再接下来是LoadImagesAndLabels类,通过该类可以加载数据集。该类继承自Dataset,需要实现def __len__(self)【用来返回数据集长度】和def __getitem__(self, index)【根据索引读取并处理单个样本】这两个方法。


    def __init__中主要是一些初始化参数,path是我们生成的train.txt文件【我这里是在dataset/train.txt】。augment是否采用数据增强。

    # Excerpt: class header and the attribute setup at the start of LoadImagesAndLabels.__init__.
    1. class LoadImagesAndLabels(Dataset): # Dataset of images and their label files, for training/testing
    2. # path: a txt file listing image paths (a directory, or a list of such paths, is also accepted)
    3. def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
    4. cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''):
    5. self.img_size = img_size
    6. self.augment = augment
    7. self.hyp = hyp # augmentation hyperparameters
    8. self.image_weights = image_weights
    9. self.rect = False if image_weights else rect # image_weights forces rectangular training off
    10. self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
    11. self.mosaic_border = [-img_size // 2, -img_size // 2] # negative half-size offsets; NOTE(review): consumed by mosaic code not shown here - confirm
    12. self.stride = stride
    13. self.path = path



    # Collect image paths from `path` (a directory, a txt list file, or a list of these).
    try:
        f = []  # image files
        for p in path if isinstance(path, list) else [path]:
            p = Path(p)  # os-agnostic
            if p.is_dir():  # directory: collect every file under it recursively
                f += glob.glob(str(p / '**' / '*.*'), recursive=True)
            elif p.is_file():  # txt file: one image path per line
                with open(p, 'r') as t:
                    t = t.read().strip().splitlines()  # splitlines = readlines()
                    parent = str(p.parent) + os.sep
                    # Resolve './'-relative entries against the txt file's directory. Only the leading
                    # './' is rewritten; str.replace would also rewrite the './' inside '../'.
                    f += [parent + x[2:] if x.startswith('./') else x for x in t]  # local to global path
            else:
                raise Exception(f'{prefix}{p} does not exist')
        # Keep supported image formats only, normalise '/' to os.sep, and sort for a stable order.
        self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats])
        if not self.img_files:  # explicit check: an assert would be stripped under `python -O`
            raise Exception(f'{prefix}No images found')
    except Exception as e:
        raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}') from e





    ['F:\\yolov7/dataset/images/0.jpg', 'F:\\yolov7/dataset/images/000000.jpg', 'F:\\yolov7/dataset/images/000001.jpg', 'F:\\yolov7/dataset/images/000002.jpg', 'F:\\yolov7/dataset/images/000003.jpg', 'F:\\yolov7/dataset/images/000004.jpg', 'F:\\yolov7/dataset/images/000005.jpg', 'F:\\yolov7/dataset/images/000008.jpg', 'F:\\yolov7/dataset/images/00001.jpg', 'F:\\yolov7/dataset/images/000011.jpg', 'F:\\yolov7/dataset/images/000012.jpg', 'F:\\yolov7/dataset/images/000013.jpg', 'F:\\yolov7/dataset/images/000014.jpg', 'F:\\yolov7/dataset/images/000017.jpg', 'F:\\yolov7/dataset/images/000021.jpg', 'F:\\yolov7/dataset/images/000023.jpg', 'F:\\yolov7/dataset/images/000024.jpg', 'F:\\yolov7/dataset/images/000025.jpg', 'F:\\yolov7/dataset/images/000026.jpg', 'F:\\yolov7/dataset/images/000027.jpg', 'F:\\yolov7/dataset/images/000028.jpg', 'F:\\yolov7/dataset/images/000031.jpg', 

    self.img_files = sorted(
        x.replace('/', os.sep)
        for x in f
        if x.split('.')[-1].lower() in img_formats
    )  # supported image files only, '/' normalised to os.sep, in sorted order


    check cache


    # Check cache: derive label paths from image paths, then load or build the label cache.
    self.label_files = img2label_paths(self.img_files)  # labels
    # The .cache file sits beside the txt list when `p` is a file, otherwise beside the label directory.
    # NOTE(review): `p` is the variable left over from the path-loading loop (the last entry when
    # `path` is a list) - confirm this is intended.
    if p.is_file():
        cache_base = p
    else:
        cache_base = Path(self.label_files[0]).parent
    cache_path = cache_base.with_suffix('.cache')  # cached labels
    exists = cache_path.is_file()
    if exists:
        cache = torch.load(cache_path)  # load
        # NOTE(review): hash/version re-validation of a loaded cache is disabled here, so a stale cache
        # is reused even if image or label files changed; delete the .cache file to force a rebuild.
    else:
        cache = self.cache_labels(cache_path, prefix)  # cache


    def img2label_paths(img_paths):
        """Return the label-file path for every image path in ``img_paths``.

        The first ``<sep>images<sep>`` component is replaced by ``<sep>labels<sep>`` and the
        extension (the text after the last '.') is replaced by ``txt``. Accepts any iterable
        of path strings and returns a list.
        """
        images_dir = os.sep + 'images' + os.sep
        labels_dir = os.sep + 'labels' + os.sep
        label_paths = []
        for img_path in img_paths:
            ext = img_path.split('.')[-1]
            swapped = img_path.replace(images_dir, labels_dir, 1)
            head_and_tail = swapped.rsplit(ext, 1)
            label_paths.append('txt'.join(head_and_tail))
        return label_paths



    0 0.5697115384615385 0.6442307692307693 0.44711538461538464 0.6538461538461539


    cache = self.cache_labels(cache_path, prefix)  # no cache file on disk: (re)build it via cache_labels
    exists = False  # the cache was freshly built, not loaded






    # Read cache
    # Drop the bookkeeping entries so only per-image records remain. pop() with a default tolerates
    # caches written without them (the hash/version check on load is disabled).
    cache.pop('hash', None)  # remove hash
    cache.pop('version', None)  # remove version
    labels, shapes, self.segments = zip(*cache.values())
    self.labels = list(labels)
    self.shapes = np.array(shapes, dtype=np.float64)
    self.img_files = list(cache.keys())  # update
    self.label_files = img2label_paths(cache.keys())  # update
    if single_cls:
        for x in self.labels:
            x[:, 0] = 0  # single-class mode: set the first label column (class index) to 0
    n = len(shapes)  # number of images
    # np.int was removed in NumPy 1.24 (AttributeError); the builtin int is the same dtype.
    bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
    nb = bi[-1] + 1  # number of batches
    self.batch = bi  # batch index of image
    self.n = n
    self.indices = range(n)


    # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
    self.imgs = [None] * n
    if cache_images:
        if cache_images == 'disk':
            # Save the arrays returned by load_image as .npy files in a sibling '<image dir>_npy' directory.
            self.im_cache_dir = Path(Path(self.img_files[0]).parent.as_posix() + '_npy')
            self.img_npy = [self.im_cache_dir / Path(f).with_suffix('.npy').name for f in self.img_files]
            self.im_cache_dir.mkdir(parents=True, exist_ok=True)
        gb = 0  # bytes of cached images (displayed in GB)
        self.img_hw0, self.img_hw = [None] * n, [None] * n
        # Use the pool as a context manager so its worker threads are released (a bare ThreadPool(8)
        # was never closed). imap is lazy, so the results must be consumed inside the `with` block.
        with ThreadPool(8) as pool:
            results = pool.imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
            pbar = tqdm(enumerate(results), total=n)
            for i, x in pbar:
                if cache_images == 'disk':
                    if not self.img_npy[i].exists():
                        np.save(self.img_npy[i].as_posix(), x[0])
                    gb += self.img_npy[i].stat().st_size
                else:
                    self.imgs[i], self.img_hw0[i], self.img_hw[i] = x
                    gb += self.imgs[i].nbytes
                pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)'
            pbar.close()

    数据集处理def __getitem__(self,index)


    # Load image; (h0, w0) and (h, w) are presumably the sizes before/after load_image's resize - confirm
    img, (h0, w0), (h, w) = load_image(self, index)

    # Letterbox to the final shape: the per-batch shape in rect mode, otherwise img_size
    if self.rect:
        shape = self.batch_shapes[self.batch[index]]
    else:
        shape = self.img_size
    img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
    shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

    labels = self.labels[index].copy()
    if labels.size:
        # normalized xywh -> pixel xyxy in the letterboxed image
        labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])


    nL = len(labels)  # number of labels
    if nL:
        # Back to xywh, then normalise to 0-1 by the current image size:
        # x and w by the width (img.shape[1]), y and h by the height (img.shape[0]).
        labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])
        labels[:, [1, 3]] /= img.shape[1]
        labels[:, [2, 4]] /= img.shape[0]

    创建一个labels_out 用来存储labels

    labels_out = torch.zeros((nL, 6))



    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)


    return torch.from_numpy(img), labels_out, self.img_files[index], shapes



    计算batch_size和数据加载的工作进程数(num_workers,即代码中的nw)。DistributedSampler是分布式采样器,用于把数据集分到多张卡上训练(rank为-1即非分布式时不使用它;我这里用的是单卡)。每个GPU分到的数据量为:一个epoch的样本数/GPU数量。如果shuffle=True(DistributedSampler的默认值),每个GPU得到的数据是随机打乱的,否则按顺序划分。

    1. batch_size = min(batch_size, len(dataset))
    2. nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
    3. sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    4. loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    5. # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    6. dataloader = loader(dataset,
    7. batch_size=batch_size,
    8. num_workers=nw,
    9. sampler=sampler,
    10. pin_memory=True,
    11. collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    12. return dataloader, dataset



    # Build the loader; quad mode collates with collate_fn4 instead of collate_fn
    collate = LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn
    dataloader = loader(
        dataset,
        batch_size=batch_size,
        num_workers=nw,
        sampler=sampler,
        pin_memory=True,
        collate_fn=collate,
    )



  • 相关阅读:
    从文件下载视角来理解Web API
    Tomcat 的连接器是如何设计的(上)
  • 原文地址:https://blog.csdn.net/z240626191s/article/details/126459123