• Pytorch:Dataset类和DataLoader类


      在机器学习和深度学习框架中,尤其是在 PyTorch 中,DatasetDataLoader 是处理和加载数据的重要工具。这里我们详细探讨这两个类的结构、用途和如何实际使用它们。
      数据集(Dataset)是指存储和表示数据的类或接口。它通常用于封装数据,以便能够在机器学习任务中使用。数据集可以是任何形式的数据,比如图像、文本、音频等。数据集的主要目的是提供对数据的标准访问方法,以便可以轻松地将其用于模型训练、验证和测试。
      数据加载器(DataLoader)是一个提供批量加载数据的工具。它通过将数据集分割成小批量,并按照一定的顺序加载到内存中,以提高训练效率。数据加载器常用于训练过程中的数据预处理、批量化操作和数据并行处理等。

    一、Dataset 类

    1、定义

    Dataset 是一个抽象类,用于表示一个数据集的全部内容。在 PyTorch 中,任何继承自 torch.utils.data.Dataset 的自定义数据集需要实现两个必须的方法:

    • __getitem__(self, index)
      • 这个方法应该返回一个索引处的数据点和其对应的标签。例如,在图像数据集中,这可能是一对(图像,标签)。
    • __len__(self)
      • 这个方法返回数据集中的数据点的总数,即数据集的大小。

    2、示例

    下面是一个简单的形象化例子,展示如何创建一个用于加载图像数据集的自定义 Dataset 类:

    import torch
    from torch.utils.data import Dataset
    class IceCreamDataset(Dataset):
        def __init__(self):
            self.flavors = ["vanilla", "chocolate", "strawberry"]
    
        def __len__(self):
            return len(self.flavors)
    
        def __getitem__(self, index):
            return f"One scoop of {self.flavors[index]} ice cream"
    ice_cream_menu = IceCreamDataset()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    在这个例子中,IceCreamDataset 类定义了一个冰激凌数据。

    二、DataLoader 类

    1、定义

    DataLoader 是一个迭代器,用于将 Dataset 封装成易于访问的数据流,支持批量加载和多进程数据加载等操作。

    2、参数

    • dataset: 要加载的 Dataset 对象。
    • batch_size(可选): 每个批次加载的样本数量。即对Dataset数据集进行等分,分成的份数(每份叫作一个batch)为len(dataset)/batch_sizebatch_size通常是单次训练使用的数据量,默认为1。
    • shuffle(可选): 是否在每个训练周期开始时打乱数据。
    • num_workers(可选): 用于数据加载的进程数。

    3、示例:使用 DataLoader

    一旦定义了 Dataset,就可以使用 DataLoader 来有效地加载数据:

    from torch.utils.data import DataLoader
    
    # 创建 DataLoader,每批三份不同口味的冰激凌
    ice_cream_loader = DataLoader(ice_cream_menu)#等价于ice_cream_loader = DataLoader(ice_cream_menu,batch_size=1)
    
    for batch in ice_cream_loader:
        print(batch)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    在这个例子中,data_loader 会自动管理从 dataset 中加载数据的复杂性,如批量加载、打乱顺序和多进程加载。
    输出:

    ['One scoop of vanilla ice cream']
    ['One scoop of chocolate ice cream']
    ['One scoop of strawberry ice cream']
    
    • 1
    • 2
    • 3
    ice_cream_loader = DataLoader(ice_cream_menu,batch_size=2)
    
    • 1

    输出:

    ['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream']
    ['One scoop of strawberry ice cream']
    
    • 1
    • 2
    ice_cream_loader = DataLoader(ice_cream_menu,batch_size=3)#大于等于3的输出一样,因为就三个数据了
    
    • 1
    ['One scoop of vanilla ice cream', 'One scoop of chocolate ice cream', 'One scoop of strawberry ice cream']
    
    • 1

    三、总结

    通过组合使用 DatasetDataLoader,PyTorch 用户可以高效、灵活地处理大规模数据集。Dataset 提供了一个清晰的接口来访问单个数据点__getitem__),而 DataLoader 管理整个数据集的批量处理和并行加载,这两者的结合极大地简化了在训练深度学习模型时的数据处理工作。

    为了简单说明,以下我们将继承Dataset类的类,说成Dataset
    根据上述简单的例子,我们可以知道,Dataset可以用来导入数据集,并规定整个数据集的长度是如何计算的,并规定单个数据点的格式;而DataLoader配合Dataset使用,可以导入数据集,并规定该数据集划分的批次数量和批次大小,以及导入数据集时是否打乱数据等。


    对于:

    for batch in dataloader:
    	pass
    
    • 1
    • 2

    为了理解batchbatch_size,可以这样去想:
      假设有512个箱子,将这些箱子,每16个分成一份,一共有32份,每一份叫作一个batch,而每个batch里面一共16个箱子。每16个箱子为一批,一批一批进行拆箱,即一个batch一个batch进行处理。遍历dataloader,每次取出的是一个batch,从上面的例子可以发现,batch里面的元素是通过列表组织在一起的。

      每一个batch实际上就是DataLoaderDataset划分成的一个批次,每个batch的大小就是batch_size(除非数据集不是它的整数倍,上面也有体现)。所有batch加起来才构成整个Dataset
      如果是图片数据集,batch_size可以认为,一个batchbatch_size张图片(如果该数据集规定单个数据点是一张图片的话。)(因为DataLoader访问数据时,会按照Dataset规定的数据点规格访问)。

    四、实战

    以上是一个简单的实例,方便理解,现在我们进行实战。

    import torch
    from sklearn.datasets import load_iris
    from torch.utils.data import Dataset, DataLoader
     
    # 此函数用于加载鸢尾花数据集
    def load_data(shuffle=True):
        x = torch.tensor(load_iris().data)
        y = torch.tensor(load_iris().target)
     
        # 数据归一化
        x_min = torch.min(x, dim=0).values
        x_max = torch.max(x, dim=0).values
        x = (x - x_min) / (x_max - x_min)
     
        if shuffle:
            idx = torch.randperm(x.shape[0])
            x = x[idx]
            y = y[idx]
        return x, y
     
    # 自定义鸢尾花数据类
    class IrisDataset(Dataset):
        def __init__(self, mode='train', num_train=120, num_dev=15):
            super(IrisDataset, self).__init__()
            x, y = load_data(shuffle=True)
            if mode == 'train':
                self.x, self.y = x[:num_train], y[:num_train]
            elif mode == 'dev':
                self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]
            else:
                self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]
     
        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]
     
        def __len__(self):
            return len(self.x)
     
    batch_size = 16
     
    # 分别构建训练集、验证集和测试集
    train_dataset = IrisDataset(mode='train')
    dev_dataset = IrisDataset(mode='dev')
    test_dataset = IrisDataset(mode='test')
     
    train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48

    这段代码涉及到使用 PyTorch 加载和处理著名的鸢尾花(Iris)数据集,并将其分成训练集、验证集和测试集。下面逐部分详细解释:

    1、load_data函数:

    1. 加载数据:

      • 使用 load_iris() 函数从 scikit-learn 库中加载鸢尾花数据集。这个函数返回包含特征(data)和目标(target)的数据结构。
      • 数据转换成 PyTorch 张量,方便后续使用 PyTorch 进行操作。
    2. 归一化:

      • 对特征进行归一化处理,使得每个特征的值范围都缩放到 [0, 1] 区间。这是通过从每个特征中减去最小值,然后除以其范围(最大值 - 最小值)来实现的。
      • 归一化有助于模型训练,因为它确保了所有特征都在相同的尺度上,从而加速学习过程。
    3. 打乱数据:

      • 如果启用 shuffle,则通过生成一个随机排列的索引并重新排序数据来打乱数据集。这通常用于训练数据集,以保证每次训练的随机性和泛化能力。
      • 这里使用的方法:
        • idx = torch.randperm(x.shape[0])x.shape[0]是二维张量的行数。 torch.randperm即随机打乱(生成一个 0 到样本数量减一的随机排列),得到一个随机排列。
        • x = x[idx];y = y[idx],使用的是高级索引:使用多个整数索引访问多个元素

    2、IrisDataset类

    • IrisDataset 类继承自 Dataset。它用于封装鸢尾花数据,使其可以通过 PyTorch DataLoader 使用。
    • 在构造函数中,根据 mode(训练、验证或测试)来划分数据:
      • 训练集 (train): 使用数据集的前 num_train 个样本。
      • 验证集 (dev): 紧随训练集之后的 num_dev 个样本。
      • 测试集 (test): 剩余的样本。
    • 这种方式的好处是简单易实现,但在实际应用中可能需要更复杂的交叉验证策略来更好地评估模型。

    3、DataLoader 的使用

    • 对于每种数据集(训练、验证、测试),通过创建 DataLoader 实例来进行封装。这允许以批量方式加载数据,可选择是否打乱。
    • 批量大小 (batch_size):
      • 对于训练数据,使用较大的批量(例如 16),有助于稳定和加速训练过程。
      • 对于验证数据,也采用同样大小的批量,以保持一致性。
      • 对于测试数据,每批只有一个样本,这常用于评估模型时逐个样本进行处理。
  • 相关阅读:
    readv、io_uring、liburing and command cat
    Linux系统配置静态IP地址步骤
    2022微信小程序:解决消除button的默认样式后容器任然在中间的位置的方案
    【关于ROS_PACKAGE_PATH的含义、理解和用法】
    可持久化01Trie
    unity 使用模拟器进行Profiler性能调试
    2.C语言中常见的关键字
    六、软考-系统架构设计师笔记-软件工程基础知识
    算法与设计分析--分治算法的设计与分析
    Failed assertion: line 4349 pos 12: ‘!_dirty‘: is notflutter: true.
  • 原文地址:https://blog.csdn.net/m0_63997099/article/details/138066157