• Using Dataset and DataLoader



    d2l provides a concise helper for loading a fixed dataset (Fashion-MNIST), as shown below.

    d2l.load_data_fashion_mnist()
    # Source code
    Signature: d2l.load_data_fashion_mnist(batch_size, resize=None)
    Source:   
    def load_data_fashion_mnist(batch_size, resize=None):
        """Download the Fashion-MNIST dataset and then load it into memory.
    
        Defined in :numref:`sec_fashion_mnist`"""
        trans = [transforms.ToTensor()]
        if resize:
            trans.insert(0, transforms.Resize(resize))
        trans = transforms.Compose(trans)
        mnist_train = torchvision.datasets.FashionMNIST(
            root="../data", train=True, transform=trans, download=True)
        mnist_test = torchvision.datasets.FashionMNIST(
            root="../data", train=False, transform=trans, download=True)
        return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                                num_workers=get_dataloader_workers()),
                data.DataLoader(mnist_test, batch_size, shuffle=False,
                                num_workers=get_dataloader_workers()))
    File:      ~/anaconda3/envs/d2l/lib/python3.9/site-packages/d2l/torch.py
    Type:      function
    

    To load a dataset of our own, we need to define a custom Dataset.

    Dataset: a folder of images, plus a CSV file that lists the training images and their labels.
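
    To make the CSV layout concrete, here is a small sketch that reads such a file with pandas and encodes the labels with LabelEncoder. The rows and the column name image are made up for illustration; only the label column name and "path in the first column" follow the code in this post.

```python
import io

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical CSV content: first column is the image path, second is the class name.
csv_text = """image,label
images/0.jpg,maclura_pomifera
images/1.jpg,ulmus_rubra
images/2.jpg,maclura_pomifera
images/3.jpg,acer_rubrum
"""
data = pd.read_csv(io.StringIO(csv_text))

# LabelEncoder assigns integer ids in sorted order of the class names.
encoder = LabelEncoder()
encoded = encoder.fit_transform(data["label"])

print(data.iloc[0, 0])                    # path of the first image
print(list(encoder.classes_))             # class names, sorted
print(encoded.tolist())                   # integer id per row
print(encoder.inverse_transform([0])[0])  # map an id back to its class name
```

    Keeping a reference to the fitted encoder is useful later, because inverse_transform is the way to turn predicted ids back into class names.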

    # Define the Dataset
    import pandas as pd
    import os
    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader

    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import LabelEncoder
    import torchvision.transforms as transforms

    class CustomDataset(Dataset):
        def __init__(self, csv_file, root_dir, transform=None):
            # Column 0 holds the image path; column 'label' holds the class name
            self.data = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform
            # Map class names to integer ids; keep the encoder to decode predictions later
            self.label_encoder = LabelEncoder()
            self.labels = self.label_encoder.fit_transform(self.data['label'])
            self.num_classes = len(self.label_encoder.classes_)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])
            # Read the image and apply augmentation
            image = Image.open(img_name)
            if self.transform is not None:
                image = self.transform(image)
            # Convert the integer label into a one-hot tensor (remember to cast to float)
            label = F.one_hot(torch.tensor(self.labels[idx]),
                              num_classes=self.num_classes).float()
            return image, label
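
    The cast to float in __getitem__ matters because F.one_hot returns an integer tensor, while a loss that takes one-hot targets expects floats. A small sketch of this, assuming PyTorch 1.10 or later (where cross_entropy accepts class-probability targets) and made-up logits:

```python
import torch
import torch.nn.functional as F

num_classes = 3
class_index = torch.tensor(1)

# F.one_hot produces an integer (int64) tensor.
one_hot = F.one_hot(class_index, num_classes=num_classes)
print(one_hot.dtype)

# Cast to float so it can be used as a class-probability target.
target = one_hot.float()

# Made-up logits for a batch of one sample.
logits = torch.tensor([[0.2, 1.5, -0.3]])

# With a float one-hot target, the loss matches the one computed from the class index.
loss_from_index = F.cross_entropy(logits, class_index.unsqueeze(0))
loss_from_one_hot = F.cross_entropy(logits, target.unsqueeze(0))
print(loss_from_index.item(), loss_from_one_hot.item())
```

    If the training loop uses class indices instead, the dataset could return the integer label directly and skip the one-hot step.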
    
    # Define the training parameters and hyperparameters
    batch_size = 256
    lr, num_epoch = 0.9, 10
    
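    # Load the data
    import torch  # needed for torch.utils.data.random_split

    sample = '/kaggle/input/classify-leaves/sample_submission.csv'
    ts_path = "/kaggle/input/classify-leaves/test.csv"
    tr_path = "/kaggle/input/classify-leaves/train.csv"
    image_path = '/kaggle/input/classify-leaves'

    # Assumption: transform_train is not defined in the original post; minimal placeholder
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Build the dataset from the labelled training CSV, not the sample submission
    dataset = CustomDataset(csv_file=tr_path, root_dir=image_path, transform=transform_train)
    train_size = int(0.8 * len(dataset))
    valid_size = len(dataset) - train_size
    # Random 80/20 split into training and validation subsets
    tr_dataset, te_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size])

    tr_dataloader = DataLoader(tr_dataset, batch_size, shuffle=True)
    ts_dataloader = DataLoader(te_dataset, batch_size, shuffle=False)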
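
    The split-then-load pattern can be tried without the Kaggle files. The sketch below uses a TensorDataset of random tensors as a stand-in for the image dataset (the shapes, class count, and seed are assumptions), and passes a seeded generator so the split is reproducible:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in data: 100 random "images" of shape 3x8x8 with labels from 5 classes.
features = torch.randn(100, 3, 8, 8)
labels = torch.randint(0, 5, (100,))
dataset = TensorDataset(features, labels)

train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size

# A seeded generator makes the random split repeatable across runs.
generator = torch.Generator().manual_seed(0)
train_set, valid_set = random_split(dataset, [train_size, valid_size], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False)

X, y = next(iter(train_loader))
print(len(train_set), len(valid_set))
print(X.shape, y.shape)
```

    Both subsets returned by random_split wrap the same underlying dataset, so any transform passed to that dataset, including training-time augmentation, is also applied to the validation samples.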
    

    Summary

    A custom Dataset must implement __init__, __len__, and __getitem__, written to match what the dataset provides and what the model expects.
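
    As a minimal illustration of those three methods, here is a toy dataset that is not part of the original post: the class name SquaresDataset and its data are hypothetical, and sample i is simply the pair (i, i squared).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, n):
        # Prepare everything that indexing will need.
        self.x = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        # Number of samples; DataLoader uses it to plan the batches.
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (input, target) pair for the given index.
        x = self.x[idx]
        return x, x ** 2


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# 10 samples with batch_size=4 give batches of 4, 4, and 2.
batch_sizes = [len(xb) for xb, yb in loader]
print(batch_sizes)
```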

  • Original article: https://blog.csdn.net/weixin_45363113/article/details/133245799