• pytorch:dataloader自定义数据集制作


    1.如何自定义数据集

    • 1.数据和标签的目录结构先搞定
    • 2.写好读取数据和标签路径的函数(根据自己数据集情况来写)
    • 3.完成单个数据与标签读取函数

     

    我们以用txt文件指定数据路径与标签为实际例子

            数据集情况,train_filelist存入训练集图片,val_filelist存入验证集图片,train.txt、val.txt分别存入图片名称及标签

     

    读取txt文件中的路径和标签 

    首先,从文件中读取图片名称及标签,先暂时存储为字典结构 

    1. def load_annotations(ann_file):
    2.     data_infos = {}
    3.     with open(ann_file) as f:
    4.         samples = [x.strip().split(' ') for x in f.readlines()]
    5.         for filename, gt_label in samples:
    6.             data_infos[filename] = np.array(gt_label, dtype=np.int64)
    7.     return data_infos 

    分别将图片名及标签存入列表中 

     

    写入完整图片路径

     用dataloader实现 

    • 1.注意要使用from torch.utils.data import Dataset, DataLoader
    • 2.类名定义class FlowerDataset(Dataset),其中FlowerDataset可以改成自己的名字
    • 3.def init(self, root_dir, ann_file, transform=None):咱们要根据自己任务重写
    • 4.def getitem(self, idx):根据自己任务,返回图像数据和标签数据

    其中,getitem中的idx表示,将数据shuffle后,取的索引

    1. from torch.utils.data import Dataset, DataLoader
    2. class FlowerDataset(Dataset):
    3. def __init__(self, root_dir, ann_file, transform=None):
    4. self.ann_file = ann_file
    5. self.root_dir = root_dir
    6. self.img_label = self.load_annotations()
    7. self.img = [os.path.join(self.root_dir,img) for img in list(self.img_label.keys())]
    8. self.label = [label for label in list(self.img_label.values())]
    9. self.transform = transform
    10. def __len__(self):
    11. return len(self.img)
    12. def __getitem__(self, idx):
    13. image = Image.open(self.img[idx])
    14. label = self.label[idx]
    15. if self.transform:
    16. image = self.transform(image)
    17. label = torch.from_numpy(np.array(label))
    18. return image, label
    19. def load_annotations(self):
    20. data_infos = {}
    21. with open(self.ann_file) as f:
    22. samples = [x.strip().split(' ') for x in f.readlines()]
    23. for filename, gt_label in samples:
    24. data_infos[filename] = np.array(gt_label, dtype=np.int64)
    25. return data_infos

    数据预处理(transform)

    • 预处理的事都在上面的getitem中完成,返回的数据和标签就是建模时模型的输入和损失函数中标签的输入
    1. data_transforms = {
    2. 'train':
    3. transforms.Compose([
    4. transforms.Resize(64),
    5. transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
    6. transforms.CenterCrop(64),#从中心开始裁剪
    7. transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
    8. transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
    9. transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
    10. transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
    11. transforms.ToTensor(),
    12. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    13. ]),
    14. 'valid':
    15. transforms.Compose([
    16. transforms.Resize(64),
    17. transforms.CenterCrop(64),
    18. transforms.ToTensor(),
    19. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    20. ]),
    21. }

    根据写好的class FlowerDataset(Dataset):来实例化dataloader 

    • 1.构建数据集:分别创建训练和验证用的数据集(如果需要测试集也一样的方法)
    • 2.用Torch给的DataLoader方法来实例化(batch啥的自己定,根据你的显存来选合适的)
    • 3.打印看看数据的情况

    检验一下数据 

    1. image, label = iter(train_loader).next()
    2. sample = image[0].squeeze()
    3. sample = sample.permute((1, 2, 0)).numpy() # 将channels从第一个维度变为第三个维度,gpu要加一个.cuda.cpu方法
    4. sample *= [0.229, 0.224, 0.225]
    5. sample += [0.485, 0.456, 0.406]
    6. plt.imshow(sample)
    7. plt.show()
    8. print('Label is: {}'.format(label[0].numpy())) 

     

     最后传入模型进行训练就可以了

     

     

     

     

     

     

     

  • 相关阅读:
    2024年第16周技术复盘
    测试域: 流量回放-介绍篇
    如何解决pc端屏幕显示缩放比例125%,150%对页面布局的影响
    青龙面板-快手极速版(每天3块脚本)(废-已不能使用)
    Alibaba官方上线,SpringBoot+SpringCloud全彩指南(第五版)
    与导师沟通2023-09-14
    使用kube-bench检测Kubernetes集群安全
    .Net Core/.net 6/.Net 8 实现Mqtt服务器
    LQ0139 油漆面积【枚举】
    dubbo分布式日志调用链追踪
  • 原文地址:https://blog.csdn.net/qq_52053775/article/details/126191617