• 复杂数据没头绪?


     在深度学习模型的训练过程中,数据集是起着至关重要作用的。然而,由于任务的复杂性,深度学习模型的输入数据也有着各种各样的形式,深度学习模型搭建的过程中,如果遇到特别复杂的数据,研究者可能要花费大半的时间在数据集的预处理(包括清洗、加载等过程)中。因此,高效的加载数据集,能给研究者构建一套高效的开发流程。使用过PyTorch的读者都知道,PyTorch框架为我们提供了一套极其便利且高效率的自定义数据加载的接口。用户只需要简单的继承torch.utils.data.Dataset并且在get_item函数和__len__函数,再利用Dataloader进行封装,就可以很简单的实现数据集的自动化加载流程(个人认为设置PyTorch在数据层面上做的超级好的一个点)。

    ● MindSpore数据集加载简介 ● 

    在MindSpore中,mindspore.dataset里面的函数为我们提供了大量的数据集专有加载算子,这些算子经过优化,拥有较好的数据集加载性能。但是,由于MindSpore本身的数据加载都是在C语言层面完成的,用户很难感知到内部进行的具体操作,特别是针对coco这一类较为复杂的数据集时(就是比较黑洞,很难自己掌握)。由于笔者是一个很喜欢把模型训练的每一步都抓在自己手里的一个人,因此除了cifar10、cifar100、imagefolder等经典的数据(结构)时,尽量都希望自己完成数据集的加载流程,以便更好的了解模型模型和数据集。因此,这篇博客将会主要介绍如何使用MindSpore自定自定义类似PyTorch范式的数据集加载流程。

    ● mindspore.dataset.Generator

    Dataset 

    区别用PyTorch,MindSpore并不能像继承Dataset来完成数据集的构建,但是MindSpore为用户提供了一个类似于DataLoader的数据集封装接口。用户可以通过自定义object对象的数据集对象,然后使用GeneratorDataset进行封装,接下来我将以自定义cifar10和imagenet数据集来简单展示使用GeneratorDataset接口的方法。

    ● 自定义cifar10数据集 

    分析格式

    在定义数据集之前,我们首先要做的就是数据集的格式分析。在cifar官网中,我们可以得知数据集的基本格式,还可以通过已有的博客,查看读取cifar10的代码样例。如下图所示是cifar-10-batches-py数据集的目录文件,这里我们主要是关注data_batch和test_batch。

    加载数据

    这里我主要以torchvision中的cifar10数据加载为例,说明构建cifar10数据集的方法。

    1. train_list = [
    2. ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
    3. ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
    4. ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
    5. ['data_batch_4', '634d18415352ddfa80567beed471001a'],
    6. ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    7. ]
    8. test_list = [
    9. ['test_batch', '40351d587109b95175f43aff81a1287e'],
    10. ]
    11. ...
    12. if self.train:
    13. downloaded_list = self.train_list
    14. else:
    15. downloaded_list = self.test_list
    16. ...
    17. for file_name, checksum in downloaded_list:
    18. file_path = os.path.join(self.root, self.base_folder, file_name)
    19. with open(file_path, 'rb') as f:
    20. entry = pickle.load(f, encoding='latin1')
    21. self.data.append(entry['data'])
    22. if 'labels' in entry:
    23. self.targets.extend(entry['labels'])
    24. else:
    25. self.targets.extend(entry['fine_labels'])
    26. """可以很容易理解到,数据集文件里面有一个"data"和一个"label"键,分别拿出来就好"""
    27. self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
    28. self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC

     构建cifar10数据集并且完成预处理

    由于cifar10读取进来以后已经是数据形式,因此并不需要想用的图像解码,可以直接使用opencv或者PIL进行处理。这里以cifar10的test数据为例。

    1. import os
    2. import pickle
    3. import numpy as np
    4. import mindspore
    5. from mindspore.dataset import GeneratorDataset
    6. class CIFAR10(object):
    7. train_list = [
    8. 'data_batch_1',
    9. 'data_batch_2',
    10. 'data_batch_3',
    11. 'data_batch_4',
    12. 'data_batch_5',
    13. ]
    14. test_list = [
    15. 'test_batch',
    16. ]
    17. def __init__(self, root, train, transform=None, target_transform=None):
    18. super(CIFAR10, self).__init__()
    19. self.root = root
    20. self.train = train # training set or test set
    21. if self.train:
    22. downloaded_list = self.train_list
    23. else:
    24. downloaded_list = self.test_list
    25. self.data = []
    26. self.targets = []
    27. self.transform = transform
    28. self.target_transform = target_transform
    29. # now load the picked numpy arrays
    30. for file_name in downloaded_list:
    31. file_path = os.path.join(self.root, file_name)
    32. with open(file_path, 'rb') as f:
    33. entry = pickle.load(f, encoding='latin1')
    34. self.data.append(entry['data'])
    35. if 'labels' in entry:
    36. self.targets.extend(entry['labels'])
    37. else:
    38. self.targets.extend(entry['fine_labels'])
    39. self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
    40. self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
    41. def __getitem__(self, index):
    42. """
    43. Args:
    44. index (int): Index
    45. Returns:
    46. tuple: (image, target) where target is index of the target class.
    47. """
    48. img, target = self.data[index], self.targets[index]
    49. # doing this so that it is consistent with all other datasets
    50. # to return a PIL Image
    51. img = Image.fromarray(img)
    52. if self.transform is not None:
    53. img = self.transform(img)
    54. if self.target_transform is not None:
    55. target = self.target_transform(target)
    56. return img, target
    57. def __len__(self):
    58. return len(self.data)
    59. cifar10_test = CIFAR10(root="./cifar10/cifar-10-batches-py", train=False)
    60. cifar10_test = GeneratorDataset(source=cifar10_test, column_names=["image", "label"])
    61. cifar10_test = cifar10_test.batch(128)
    62. for data in cifar10_test.create_dict_iterator():
    63. print(data["image"].shape, data["label"].shape)
    64. (128, 32, 32, 3) (128,)
    65. (128, 32, 32, 3) (128,)
    66. (128, 32, 32, 3) (128,)
    67. (128, 32, 32, 3) (128,)

    可以从上面的代码看到,虽然语言风格不同,但是MindSpore使用GeneratorDataset依然可以为我们提供一套相对便利的数据集加载方式。对于数据集的预处理的transform代码,研究者可以将代码直接通过transform参数传入get_item函数,十分方便;同时也可以使用MindSpore语言风格,通过dataset自带的map函数,对数据集进行预处理,不过前者的语言风格更加Python,推荐使用。

    ● 自定义ImageNet 

    分析格式

    接下来是介绍ImageNet的数据集自定义过程。其实定义ImageNet数据集加载器是非常方便的,因为图像分类的这类数据集往往是具有树状结构,我们只需要[路径,标签]或者是[图像,标签]的数组对传入到get_item函数中,就可以完成数据集的预处理。

    数据加载

    这里就简单引用timm中定义folder的部分代码。

    1. def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
    2. labels = []
    3. filenames = []
    4. for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
    5. rel_path = os.path.relpath(root, folder) if (root != folder) else ''
    6. label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
    7. for f in files:
    8. base, ext = os.path.splitext(f)
    9. if ext.lower() in types:
    10. filenames.append(os.path.join(root, f))
    11. labels.append(label)
    12. if class_to_idx is None:
    13. # building class index
    14. unique_labels = set(labels)
    15. sorted_labels = list(sorted(unique_labels, key=natural_key))
    16. class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    17. images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    18. if sort:
    19. images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    20. return images_and_targets, class_to_idx

    可以看到,我们只需要遍历目录,得到images_and_target就好。

    Mixup和Cutmix的使用

    在ImageNet中,我们常常会使用Mixup和Cutmix等数据增强,但是在对齐进行数据增强的时候,数据集已经是变成[batch_size, channel, height, width]形式出来的,在get_item进行数据预处理的函数是针对单个样本的。在PyTorch中,Mixup和Cutmix是在将数据取出,输入模型之前应用的。在MindSpore中,我们只需要在使用dataset.batch函数之后再对数据集进行预处理。具体的代码可以参考我的博客如何用MindSpore实现自动数据增强

    https://blog.csdn.net/qq_31768873/article/details/121283169),这里展示部分代码。

    1. if (mix_up > 0. or cutmix > 0.) and not is_training:
    2. # if use mixup and not training(False), one hot val data label
    3. one_hot = C.OneHot(num_classes=num_classes)
    4. dataset = dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
    5. operations=one_hot)
    6. dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=num_parallel_workers)
    7. if (mix_up > 0. or cutmix > 0.) and is_training:
    8. mixup_fn = Mixup(
    9. mixup_alpha=mix_up, cutmix_alpha=cutmix, cutmix_minmax=None,
    10. prob=mixup_prob, switch_prob=switch_prob, mode=mixup_mode,
    11. label_smoothing=label_smoothing, num_classes=num_classes)
    12. dataset = dataset.map(operations=mixup_fn, input_columns=["image", "label"],
    13. num_parallel_workers=num_parallel_workers)
    14. return dataset

    ● FAQ 

    自定义数据集的时候,千万要注意要重载len函数,没有这个函数,对象是无法感知数据集大小的。 


    ● 总结 

    本文介绍了如何使用GeneratorDataset这个接口自定义MindSpore数据集。虽然MindSpore为我们提供了好用的专有数据算子,但是由于数据加载在C语言层面完成,相对于torchvision来说存在着无法感知的缺陷,因此可以尝试使用GeneratorDataset自定义加载,把握每一步细节。(当然,其实也可以去torchvision搬代码拿GeneratorDataset封装就好~)

    MindSpore官方资料

    GitHub : https://github.com/mindspore-ai/mindspore

    Gitee : https : //gitee.com/mindspore/mindspore

    官方QQ群 : 486831414 

  • 相关阅读:
    Python sort面试题目
    docker Cgroup资源控制
    如何给双系统电脑分区?
    vue介绍
    TinyShell(CSAPP实验)
    mysql5.7安装配置教程(一看就会)
    GLTF在线编辑器
    css3新增伪元素有哪些?
    2251: 【区赛】【海曙2017】波波爱看NBA
    上传文件-读取excel文件数据
  • 原文地址:https://blog.csdn.net/Kenji_Shinji/article/details/125455318