• MindSpore处理自定义数据集的时候报错


    问题描述:

    我准备自己写一下vgg16做三分类的网络,我在网上爬了一些图片,数据命令格式为:cat_0_0.jpg、dog_0_1.jpg、person_0_2.jpg...,中间一位数字表示图片序号,最后一位数字为标签label

    整个数据集处理代码如下:

    1. import os
    2. import numpy as np
    3. from PIL import Image
    4. import mindspore.common.dtype as mstype
    5. import mindspore.dataset as ds
    6. import mindspore.dataset.transforms.c_transforms as C
    7. import mindspore.dataset.vision.c_transforms as vc
    8. class _dcp_Dataset:
    9.     def __init__(self,img_root_dir,device_target="CPU"):
    10.         if not os.path.exists(img_root_dir):
    11.             raise RuntimeError(f"the input image dir {img_root_dir} is invalid")
    12.         self.img_root_dir=img_root_dir
    13.         self.img_names=[i for i in os.listdir(img_root_dir) if i.endswith(".jpg")]
    14.         self.target=device_target
    15.     def __len__(self):
    16.         return len(self.img_names)
    17.     def __getitem__(self, index):
    18.         img_name=self.img_names[index]
    19.         im=Image.open(os.path.join(self.img_root_dir,img_name))
    20.         image=np.array(im)
    21.         label_str=img_name.split("_")[-1]
    22.         label_str=label_str.split(".")[0]
    23.         label=np.array(label_str)
    24.         return image,label
    25. def creat_dataset(dataset_path,batch_size=2,num_shards=1,shard_id=0,device_target="CPU"):
    26.     dataset=_dcp_Dataset(dataset_path,device_target)
    27.     data_set=ds.GeneratorDataset(dataset,["image","label"],shuffle=True,num_shards=1,shard_id=0)
    28.     image_trans=[
    29.         vc.Resize((224,224)),
    30.         vc.RandomHorizontalFlip(),
    31.         vc.Rescale(1/255,shift=0),
    32.         vc.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)),
    33.         vc.HWC2CHW
    34.     ]
    35.     label_trans=[C.TypeCast(mstype.int32)]
    36.     data_set=data_set.map(operations=image_trans,input_columns=["image"])
    37.     data_set=data_set.map(operations=label_trans,input_columns=["label"])
    38.     # data_set=data_set.shuffle(buffer_size=batch_size)
    39.     data_set=data_set.batch(batch_size=batch_size,drop_remainder=True)
    40.     # data_set=data_set.repeat(1)
    41.     return data_set
    42. if __name__ == '__main__':
    43.    data=creat_dataset("./image_DCP")
    44.    print(data)
    45.    data_loader = data.create_dict_iterator()
    46.    for i, data in enumerate(data_loader):
    47.         print(i)
    48.         print(data)

    当我使用creat_dict_iterator()迭代数据的时候报错:

    请问社区的大佬有谁能帮忙看一下这是哪里的问题吗?

    解答:

    1. 从你的脚本发现:

    这个需要修改成 vc.HWC2CHW() ,后面少了一个括号。

    2. 如果有类似的错误,一般的调试方法如下:可以先保留 GeneratorDataset,然后再增加一个map操作,再增强一个map操作,这样就能确定是哪个环节出了问题。

    3. 另一个,我们看看能不能添加下检验,判断下你这个出错的场景。

    我的意思是怎么样确认是哪个数据处理步骤报错了。

    例如:你有 xxDataset -> map -> map -> batch 这样的数据处理流程。

    你可以按如下方式调试脚本:

    1. 只保留 xxDataset,然后运行下脚本,查看是否报错;

    2. 保留 xxDataset -> map,然后运行下脚本,查看是否报错;

    3. 保留 xxDataset -> map -> map,然后运行下脚本,查看是否报错;

    4. 保留 xxDataset -> map -> map -> batch,然后运行下脚本,查看是否报错;

    按照上述的方法,就能确认是哪个map/batch出错了。

    进一步,如果 map 操作是以 img_trans 列表的方式传入的,那么可以把 img_trans 中的操作进一步减少,就能确认是哪个具体的trans出错了。

  • 相关阅读:
    使用XLua在Unity中获取lua全局变量和函数
    【科普向】什么是CPU、什么是GPU?本机Win11的CPU和GPU配置如何
    IOS面试题object-c 51-60
    借助 ControlNet 生成艺术二维码 – 基于 Stable Diffusion 的 AI 绘画方案
    小红书怎么看关键词排名?如何提升笔记自然搜索排名
    聊一聊小程序单聊页面构思
    微信小程序返回上一页刷新组件数据
    vscode调试webpack项目的方法
    MyBatis学习:使用Like进行模糊查询,MyBatis怎么传参或者组装模糊条件
    Redis哨兵模式配置文件详解
  • 原文地址:https://blog.csdn.net/weixin_45666880/article/details/125622332