• PyTorch入门之【dataset】


    参考:https://www.bilibili.com/video/BV1DV4y1y7KG/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620

    使用Pytorch自带的dataset

    在 PyTorch 中,torchvision.datasets 包中提供了许多经典数据集的实现,你可以使用它们来训练和测试模型。
    当然这些数据集是在服务器上的它在使用的时候是联网下载的。首次运行会下载,再次运行就不用下载了。
    这里以经典的MNIST数据集为例。
    总代码如下:

    import torch
    from torchvision import datasets
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader
    from torchvision import transforms
    
    # define a transform
    transform = transforms.Compose([
        transforms.Resize(24),
        transforms.RandomRotation(10),
        transforms.ToTensor()
    ])
    
    # download training & testing dataset
    training_data = datasets.MNIST(
        root='data',
        train=True,
        download=True,
        transform=transform
    )
    
    test_data = datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transform
    )
    
    # create label to idx dictionary
    labels = {i: training_data.classes[i] for i in range(len(training_data.classes))}
    
    # display images in MNIST
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
    
    # create dataloader
    train_data_loader = DataLoader(training_data, batch_size=16, shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=16, shuffle=True)
    print(next(iter(train_data_loader))[0].shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47

    下面挨个看各个模块的作用:

    # define a transform
    transform = transforms.Compose([
        transforms.Resize(24),
        transforms.RandomRotation(10),
        transforms.ToTensor()
    ])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    这段代码定义了一个数据转换管道,它将一系列的图像处理操作串联起来,以便对图像进行预处理。

    • transforms.Grayscale():将彩色图像转换为灰度图像。
    • transforms.Resize(24):调整图像的大小为 24x24 像素。
    • transforms.RandomRotation(10):随机旋转图像最多 10 度,增加数据的多样性和鲁棒性。
    • transforms.ToTensor():将图像转换为张量形式,以便进行后续的数据处理和模型训练。

    通过将上述操作按照顺序组合在一起,你可以定义一个 transform 对象,用于对图像数据集中的每个图像进行预处理。该 transform 对象被用于加载 MNIST 数据集,并且在 DataLoader 中配合使用。这样的数据预处理流程在深度学习中非常常见,它能够帮助提高模型训练的效果和泛化能力。你可以根据自己的需求,定制不同的转换操作,以适应不同的任务和数据集特点。

    # download training & testing dataset
    training_data = datasets.MNIST(
        root='data',
        train=True,
        download=True,
        transform=transform
    )
    
    test_data = datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transform
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    上述代码就是下载training_data和test_data数据。
    download=True 参数用于指定是否下载数据集。当该参数设置为 True 时,如果数据集尚未下载,则会自动下载数据集。如果数据集已经存在,将不会再次下载。在加载数据集时 datasets.MNIST() 会检查文件是否下载过。

    # create label to idx dictionary
    labels = {i: training_data.classes[i] for i in range(len(training_data.classes))}
    
    • 1
    • 2

    这段代码的作用是将 MNIST 训练集的类别标签映射为整数索引,并将其存储在 labels 字典中。
    这个MNIST 训练集是用来区分0-9的数据集,故这里就可以将0映射到0,1映射到1以此类推。

    # display images in MNIST
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    上述代码就是将MNIST数据集中随机的生成9个图片打印出来,为了验证一下我们的MNIST数据集是否成功的加载

    # create dataloader
    train_data_loader = DataLoader(training_data, batch_size=16, shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=16, shuffle=True)
    print(next(iter(train_data_loader))[0].shape)
    
    • 1
    • 2
    • 3
    • 4

    上述代码用于创建数据加载器 (DataLoader),设置批次以及是否shuffle。

    用户自定义的dataset

    import torch
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    
    
    # define a transform
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(24),
        transforms.RandomRotation(10),
        transforms.ToTensor()
    ])
    
    # create dataset
    my_mnist = ImageFolder(root='./my-mnist', transform=transform)
    
    # create label to idx dictionary
    labels = {i: my_mnist.classes[i] for i in range(len(my_mnist.classes))}
    
    # display images in MNIST
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(my_mnist), size=(1,)).item()
        img, label = my_mnist[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
    
    # create dataloader
    train_data_loader = DataLoader(my_mnist, batch_size=16, shuffle=True)
    print(next(iter(train_data_loader))[0].shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    总的代码几乎差不多,唯一有区别的就是数据是从自己定义的路径下加载的。
    使用 ImageFolder 类创建数据集 my_mnist

  • 相关阅读:
    提升程序运行速度-计算加速的20种方法
    SQL(Structured Query Language)—结构化查询语言
    实现高效消息传递:使用RabbitMQ构建可复用的企业级消息系统
    第0次 序言
    15. Canvas 和 SVG 的区别?
    MySQL主从复制与读写分离
    前端学习笔记005:数据传输 + AJAX + axios
    Flink SQL管理平台flink-streaming-platform-web安装搭建
    31.springboot中的注解总结(spring,springmvc,springboot,mybatis,dubbo)
    zookeeper选举机制详解
  • 原文地址:https://blog.csdn.net/qq_46527915/article/details/133611104