首先MindSpore Data提供了简洁、丰富的数据读取、处理、增强等功能;同时使用读取数据的流程,主要分为三步(使用和PyTorch中数据读取方式类似):
首先加载要使用的数据集,根据实际使用的数据集格式,从以下三种数据集读取方式选取一种即可:
目前已经支持的常用数据集有:MNIST, CIFAR-10, CIFAR-100, VOC, ImageNet, CelebA。如果使用以上开源数据集或者已经将所使用的数据整理为以上标准数据集格式,可以直接使用如下方法加载数据集。以CIFAR-10为例:
- import mindspore.dataset as ds
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR)
数据集加载好之后,就可以调用接口create_dict_iterator()创建迭代器读取数据,后面两种方式同理。
- for data in cifar_ds.create_dict_iterator():# In CIFAR-10 dataset, each dictionary of data has keys "image" and "label".
-
- print(data["image"])
- print(data["label"])
目前支持的特定格式数据集为:MindRecord。MindRecord格式的数据读取性能更优,推荐用户将数据转换为MindRecord格式。转换示例如下:
- from mindspore.mindrecord import Cifar10ToMR
-
- cifar10_path = "./cifar-10-batches-py"
- mindrecord_path = "./cifar10.mindrecord"
- cifar10_transformer = Cifar10ToMR(cifar10_path, mindrecord_path)
- cifar10_transformer.transform(["label"])
MindRecord数据加载:
- import mindspore.dataset as ds
-
- CV_FILE_NAME = "./cifar10.mindrecord"
- cifar_ds = ds.MindDataset(dataset_file=CV_FILE_NAME,columns_list=["data","label"], shuffle=True)
提供的自定义数据集加载方式为:GeneratorDataset接口。GeneratorDataset接口需要自己实现一个生成器,生成训练数据和标签,适用于较复杂的任务。
GeneratorDataset()需要传入一个生成器,生成训练数据。
- import mindspore.dataset as ds
-
- class Dataset:
- def __init__(self, image_list, label_list):
- super(Dataset, self).__init__()
- self.imgs = image_list
- self.labels = label_list
-
- def __getitem__(self, index):
- img = Image.open(self.imgs[index]).convert('RGB')
- return img, self.labels[index]
-
- def __len__(self):
- return len(self.imgs)
-
-
- class MySampler():
- def __init__(self, dataset):
- self.__num_data = len(dataset)
-
- def __iter__(self):
- indices = list(range(self.__num_data))
- return iter(indices)
-
- dataset = Dataset(save_image_list, save_label_list)
- sampler = MySampler(dataset)
- cifar_ds = ds.GeneratorDataset(dataset, column_names=["image", "label"], sampler=sampler, shuffle=True)
以上例子中 dataset是一个生成器,产生image和label。
提供 c_transforms 和 py_transforms 两个模块来供用户完成数据增强操作,两者的对比如下:
| 模块名称 | 实现 | 优缺点 |
| c_transforms | 基于C++的OpenCV实现 | 性能较高 |
| py_transforms | 基于Python的PIL实现 | 性能较差,但是可以自定义增强函数 |
使用建议:如果不需要自定义增强函数,并且c_transforms中有对应的实现,建议使用c_transforms模块。
目前c_transforms接口包括两部分:mindspore.dataset.transforms.c_transforms和mindspore.dataset.vision.c_transforms。
使用方法:
1.定义好数据增强函数:把多个增强函数加入到一个list中,并调用Compose封装;
2.调用dataset.map()函数,将定义好的函数或算子作用于指定的数据列。
示例代码如下:
- import mindspore.dataset as ds
- import mindspore.dataset.vision.c_transforms as CV_transforms
- import mindspore.dataset.transforms.c_transforms as C_transforms
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')#定义增强函数列表
- transforms_list = C_transforms.Compose[CV_transforms.RandomCrop((32, 32), (4, 4, 4, 4)), CV_transforms.RandomHorizontalFlip(), CV_transforms.Rescale(rescale, shift), CV_transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), CV_transforms.HWC2CHW()]#调用map()函数
- cifar_ds = cifar_ds.map(operations=transforms_list, input_columns="image")
其中,input_columns为指定要做增强的数据列,operations为定义的增强函数。
py_transforms接口也包括两部分mindspore.dataset.transforms.py_transforms和mindspore.dataset.vision.py_transforms。
使用方法:和c_transforms模块中的使用方法类似。示例代码如下:
- import mindspore.dataset as ds
- import mindspore.dataset.vision.py_transforms as py_vision
- import mindspore.dataset.transforms.py_transforms as py_transforms
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')
- transform_list = py_transforms.Compose([
- py_vision.ToPIL(),
- py_vision.RandomCrop((32, 32), (4, 4, 4, 4)),
- py_vision.RandomHorizontalFlip(),
- py_vision.ToTensor(),
- py_vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
- cifar_ds = cifar_ds.map(operations=transforms_list, input_columns="image")
使用py_transforms自定义增强函数:
自定义增强函数可参考MindSpore源码中的py_transforms_util.py脚本。下面以RandomBrightness为例,说明自定义增强算子的定义方式:
- #自定义增强函数定义class RandomBrightness(object):
- """
- Randomly adjust the brightness of the input image.
- Args:
- brightness (float): Brightness adjustment factor (default=0.0).
- Returns:
- numpy.ndarray, image.
- """
- def __init__(self, brightness=0.0):
- self.brightness = brightness
- def __call__(self, img):
- alpha = random.uniform(-self.brightness, self.brightness)
- return (1-alpha) * img
自定义算子的调用和py_transforms_util.py中的算子调用没有区别。
数据处理操作有:zip、shuffle、map、batch、repeat。
| 数据处理操作 | 说明 |
| zip | 合并多个数据集 |
| shuffle | 混洗数据 |
| map | 将函数和算子作用于指定列数据 |
| batch | 将数据分批,每次迭代返回一个batch的数据 |
| repeat | 对数据集进行复制 |
一般训练过程中都会用到shuffle、map、batch、repeat,如下示例:
- import mindspore.dataset as ds
- import mindspore.dataset.vision.c_transforms as CV_transforms
- import mindspore.dataset.transforms.c_transforms as C_transforms
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')
- transform_list = C.Compose([
- CV.RandomCrop((32, 32), (4, 4, 4, 4)),
- CV.RandomHorizontalFlip(),
- CV.Rescale(rescale, shift),
- CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
- CV.HWC2CHW()])# map()
- cifar_ds.map(input_columns="image", operations=transforms_list)# batch()
- cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)# repeat()
- cifar_ds = cifar_ds.repeat(repeat_num)
在实际使用过程中,需要组合使用这几个操作时,为达到最优性能,推荐按照如下顺序: 数据集加载并shuffle -> map -> batch -> repeat。
以下简单介绍一下数据处理函数的使用方法:
方式一:加载数据集时shuffle
- import mindspore.dataset as ds
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR, shuffle=True, usage='train')
方式二:加载数据集后shuffle
- import mindspore.dataset as ds
-
- DATA_DIR = "./cifar-10-batches-bin/"
- cifar_ds = ds.Cifar10Dataset(DATA_DIR, usage='train')
- cifar_ds = cifar_ds.shuffle(buffer_size=10000)
参数说明:
buffer_size:buffer_size越大,混洗程度越大,时间消耗更大
- func = lambda x : x*2
- cifar_ds = cifar_ds.map(input_columns="data", operations=func)
参数说明:
input_columns:函数作用的列数据
operations:对数据做操作的函数
cifar_ds = cifar_ds.batch(batch_size=32, drop_remainder=True, num_parallel_workers=4)
参数说明:
drop_remainder:舍弃最后不完整的batch
num_parallel_workers: 用几个线程来读取数据
cifar_ds = cifar_ds.repeat(count=2)
参数说明:
count: 数据集复制数量
- import mindspore.dataset as ds
-
- DATA_DIR_1 = "custom_dataset_dir_1/"
- DATA_DIR_2 = "custom_dataset_dir_2/"
- imagefolder_dataset_1 = ds.ImageFolderDatasetV2(DATA_DIR_1)
- imagefolder_dataset_2 = ds.ImageFolderDatasetV2(DATA_DIR_2)
- imagefolder_dataset = ds.zip((imagefolder_dataset_1, imagefolder_dataset_2))
详细代码请前往MindSpore论坛进行下载:华为云论坛_云计算论坛_开发者论坛_技术论坛-华为云
说明:严禁转载本文内容,否则视为侵权。