• prtorch.数据的导入与导出


    我们在整体代码中展示了怎么导入Fashion-MNIST这个dataset,这份数据包含六万个训练数据和一万个测试数据,每份数据都有一个28*28的灰度图,你可以把他们分成十个类。

    我们导入FashionMNIST Dataset需要使用到以下的参数

    root = 'data' 
    # 数据存储的路径
    
    train=True 
    # 是训练数据还是测试数据
    
    download=True 
    # 从互联网上下载数据 否则数据在根目录不存在
    
    transform = ToTensor 
    # 和target_transform一样,指定训练数据的转换器,将数据格式转换为tensor
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    import torch
    from torch.utils.data import Dataset
    from torchvision import datasets
    from torchvision.transforms import ToTensor
    import matplotlib.pyplot as plt
    
    
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    从training_data中拿到一组数据

    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    
    '''
    使用matplotlib来可视化训练的一些样本
    '''
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        # 产生一个随机数
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        # 拿到一个随机的照片和照片的标签
        img, label = training_data[sample_idx]
        # 添加照片
        figure.add_subplot(rows, cols, i)
        # 添加照片的名称
        plt.title(labels_map[label])
        # 关闭坐标轴
        plt.axis("off")
        # 在相应位置绘图
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    创建我们自己的数据导入器,一个DataSet的类必须包含三个函数 __init__``__len__``__getitem__

    '''
    为文件创建自定义数据集 DataSet
    '''
    
    import os
    import pandas as pd
    from torchvision.io import read_image
    
    class CustomImageDataset(Dataset):
        # 这个函数在类实例化时初始化一次
        def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
            # 使用pd读取csv文件
            self.img_labels = pd.read_csv(annotations_file)
            # img的文件目录
            self.img_dir = img_dir
            self.transform = transform
            self.target_transform = target_transform
    
        # 返回img的数量,这个函数会被python自带的len()调用
        def __len__(self):
            return len(self.img_labels)
    
        # 返回第idx个数据
        def __getitem__(self, idx):
            # 获取img的路径
            img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
            # 使用read_img将img转换为tensor
            image = read_image(img_path)
            # 这个是数据的标签 也就是分类的实际结果
            label = self.img_labels.iloc[idx, 1]
            # 如果我们定义了transform 则使用transform来处理图像
            if self.transform:
                image = self.transform(image)
            # 如果我们定义了target_transform 则使用target_transform来处理图像的标签
            if self.target_transform:
                label = self.target_transform(label)
            # 返回图像和标签
            return image, label
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    使用我们自定义的DataSet

    '''
    DataSet每次返回一组数据(一个图像和一个对应的标签),但是在训练模型时我们希望一次导入多个数据,
    这里我们解释一下为什么一次我们经常看到导入数据量为64,32,100,而不是全部的数据呢,其实只要们的计算机内存和显卡内存足够大,是可以的,但是往往我们的计算机内存和显卡内存都是有限的,我们一次导入过多的数据就是out of memery,这样还好,但是你想想你要是训练了好几天out of memery 那不直接背过气去,所以这个值我们要设置的合理,既不要太大,也不要太小,太小的话会导致数据导入太慢
    并且每次导入数据都进行重新洗牌,来防止过拟合现象,并且使用多线程来加速导入的过程
    '''
    from torch.utils.data import DataLoader
    
    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
    
    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    img = train_features[0].squeeze()
    label = train_labels[0]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(f"Label: {label}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    我们还发现,现在有两个参数我们是没有自己设置的,就是transform和target_transform

    torchvision.Transforms模块提供了几个开箱即用的常用转换

    import torch
    from torchvision import datasets
    from torchvision.transforms import ToTensor, Lambda
    
    ds = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
        # 将整数转换为一个one-hot编码的tensor
        target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
    )
    
    for i in range(10):
        print(torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(i), value=1))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
  • 相关阅读:
    八大排序总是忘?快来这里~
    学生静态HTML个人博客主页【Web大学生网页作业成品】HTML+CSS+JavaScript
    verilog刷题:LFSR 产生特定序列
    算法通过村第七关-树(递归/二叉树遍历)白银笔记|递归实战
    什么是测试界天花板,我今天算是见到了
    西工大&ANU&CSIRO&IIAI提出基于排序的伪装目标检测网络RankNet,并提供了最大的COD数据集!...
    SpringMVC笔记
    利用已存在的conda环境
    天池Python练习07-字符串
    Monaco Editor教程(十九):编辑器自动完成建议项CompletionItem的配置详解
  • 原文地址:https://blog.csdn.net/weixin_43903639/article/details/126907953