问题描述:
我准备自己写一下vgg16做三分类的网络,我在网上爬了一些图片,数据命令格式为:cat_0_0.jpg、dog_0_1.jpg、person_0_2.jpg...,中间一位数字表示图片序号,最后一位数字为标签label
整个数据集处理代码如下:
- import os
- import numpy as np
- from PIL import Image
- import mindspore.common.dtype as mstype
- import mindspore.dataset as ds
- import mindspore.dataset.transforms.c_transforms as C
- import mindspore.dataset.vision.c_transforms as vc
-
- class _dcp_Dataset:
- def __init__(self,img_root_dir,device_target="CPU"):
- if not os.path.exists(img_root_dir):
- raise RuntimeError(f"the input image dir {img_root_dir} is invalid")
- self.img_root_dir=img_root_dir
- self.img_names=[i for i in os.listdir(img_root_dir) if i.endswith(".jpg")]
- self.target=device_target
-
- def __len__(self):
- return len(self.img_names)
-
- def __getitem__(self, index):
- img_name=self.img_names[index]
- im=Image.open(os.path.join(self.img_root_dir,img_name))
- image=np.array(im)
- label_str=img_name.split("_")[-1]
- label_str=label_str.split(".")[0]
- label=np.array(label_str)
-
- return image,label
-
- def creat_dataset(dataset_path,batch_size=2,num_shards=1,shard_id=0,device_target="CPU"):
- dataset=_dcp_Dataset(dataset_path,device_target)
- data_set=ds.GeneratorDataset(dataset,["image","label"],shuffle=True,num_shards=1,shard_id=0)
- image_trans=[
- vc.Resize((224,224)),
- vc.RandomHorizontalFlip(),
- vc.Rescale(1/255,shift=0),
- vc.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)),
- vc.HWC2CHW
- ]
- label_trans=[C.TypeCast(mstype.int32)]
-
- data_set=data_set.map(operations=image_trans,input_columns=["image"])
- data_set=data_set.map(operations=label_trans,input_columns=["label"])
- # data_set=data_set.shuffle(buffer_size=batch_size)
- data_set=data_set.batch(batch_size=batch_size,drop_remainder=True)
- # data_set=data_set.repeat(1)
-
- return data_set
-
-
- if __name__ == '__main__':
- data=creat_dataset("./image_DCP")
- print(data)
- data_loader = data.create_dict_iterator()
- for i, data in enumerate(data_loader):
- print(i)
- print(data)
当我使用creat_dict_iterator()迭代数据的时候报错:

请问社区的大佬有谁能帮忙看一下这是哪里的问题吗?
解答:
1. 从你的脚本发现:

这个需要修改成 vc.HWC2CHW() ,后面少了一个括号。
2. 如果有类似的错误,一般的调试方法如下:可以先保留 GeneratorDataset,然后再增加一个map操作,再增强一个map操作,这样就能确定是哪个环节出了问题。
3. 另一个,我们看看能不能添加下检验,判断下你这个出错的场景。
我的意思是怎么样确认是哪个数据处理步骤报错了。
例如:你有 xxDataset -> map -> map -> batch 这样的数据处理流程。
你可以按如下方式调试脚本:
1. 只保留 xxDataset,然后运行下脚本,查看是否报错;
2. 保留 xxDataset -> map,然后运行下脚本,查看是否报错;
3. 保留 xxDataset -> map -> map,然后运行下脚本,查看是否报错;
4. 保留 xxDataset -> map -> map -> batch,然后运行下脚本,查看是否报错;
按照上述的方法,就能确认是哪个map/batch出错了。
进一步,如果 map 操作是以 img_trans 列表的方式传入的,那么可以把 img_trans 中的操作进一步减少,就能确认是哪个具体的trans出错了。