• 【李沐深度学习笔记】图片分类数据集


    课程地址和说明

    图片分类数据集p3
    本系列文章是我学习李沐老师深度学习系列课程的学习笔记,可能会对李沐老师上课没讲到的进行补充。本文还参考了【李沐3】3.5、图像分类数据集

    图片分类数据集

    MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集

    导入数据集

    %matplotlib inline
    import torch
    import torchvision
    from torch.utils import data
    from torchvision import transforms
    import os
    os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" # jupyter notebook运行d2l内核挂掉解决方法
    from d2l import torch as d2l
    # 设置中文字体
    from pylab import mpl
    mpl.rcParams["font.sans-serif"] = ["SimHei"]   # 设置显示中文字体
    mpl.rcParams["axes.unicode_minus"] = False   # 设置正常显示符号
    d2l.use_svg_display()
    # 导入必要的库和模块
    # - torch:PyTorch库,用于构建和训练神经网络
    # - torchvision:PyTorch中用于处理图像数据的库
    # - torch.utils.data:PyTorch中用于处理数据加载的模块
    # - torchvision.transforms:用于定义和应用数据转换的模块
    # - d2l.torch:Dive into Deep Learning(《动手深度学习》)书中提供的PyTorch实用函数和工具
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中

    首先下载数据集

    # 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
    # 并除以255使得所有像素的数值均在0-1之间
    trans = transforms.ToTensor()
    
    # 创建FashionMNIST数据集的训练集实例
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data",  # 数据集存放的根目录(根目录的上一级目录(父级目录)下的data文件夹)
        train=True,      # 表示加载训练集
        transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
        download=True    # 是否下载数据集(如果尚未下载的话)
    )
    # 创建FashionMNIST数据集的测试集实例
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data",  # 数据集存放的根目录
        train=False,     # 表示加载测试集(测试数据集是用来验证模型好坏的数据集)
        transform=trans, # 数据变换,将图像数据转换为Tensor格式并归一化
        download=True    # 是否下载数据集(如果尚未下载的话)
    )
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    运行结果:

    # 查看一下数据集和训练集的个数
    print(len(mnist_train))
    print(len(mnist_test))
    
    • 1
    • 2
    • 3

    运行结果:
    60000
    10000

    # 查看数据集中一个图片的大小和颜色通道数(黑白图片,通道数为1)
    print(mnist_train[0][0].shape)
    
    • 1
    • 2

    运行结果:
    torch.Size([1, 28, 28])

    展示样本

    # 将标签从数字转成具体的文字
    def get_fashion_mnist_labels(labels):
        """
        返回Fashion-MNIST数据集的文本标签
        参数:
            labels: 包含数值标签的列表或数组
        返回:
            包含对应文本标签的列表
        """
        text_labels = ['T恤衫', '裤子', '套头衫', '裙子', '大衣',
                       '凉鞋', '衬衫', '运动鞋', '包', '短靴']
        # 比如1表示苹果,这里以前标记的是1,现在转换为苹果
        return [text_labels[int(i)] for i in labels]
    
    def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
        """
        绘制图像列表
        参数:
            imgs: 包含图像的列表
            num_rows: 图像展示的行数
            num_cols: 图像展示的列数
            titles: 可选参数,图像标题的列表
            scale: 可选参数,控制图像的缩放比例
        返回:
            无返回值,显示绘制的图像
        """
        # 计算绘图区域的尺寸
        figsize = (num_cols * scale, num_rows * scale)
        # 创建一个具有指定尺寸的子图
        _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
        # 将子图数组展平,以便逐个访问每个子图
        axes = axes.flatten()
        
        # 遍历图像列表并在每个子图中绘制图像
        for i, (ax, img) in enumerate(zip(axes, imgs)):
            if torch.is_tensor(img):
                # 如果图像是PyTorch的张量,将其转换为NumPy数组并在子图上显示
                ax.imshow(img.numpy())
            else:
                # 如果图像是PIL图像,直接在子图上显示
                ax.imshow(img)
            # 隐藏子图的x轴和y轴
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            if titles:
                # 如果提供了标题列表,设置当前子图的标题
                ax.set_title(titles[i])
        
        # 返回绘制的子图数组
        return axes
    # 大小为固定数字批量的数据,用next拿到第一个小批量
    x, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
    # 18个图片,28x28像素,2行9列,展示第一批训练集样本和其对应的标签
    show_images(x.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54

    运行结果:

    读取一小批量数据,大小为batch_size

    batch_size = 256
    def get_dataloader_workers(): #@save
        """使用4个进程来读取数据"""
        return 4
    
    # 创建训练数据迭代器
    train_iter = data.DataLoader(
        mnist_train,                       # 使用的数据集实例
        batch_size,                        # 每个批次的样本数量
        shuffle=True,                      # 是否在每个epoch前打乱数据顺序
        num_workers=get_dataloader_workers() # 用于加载数据的进程数量
    )
    # 计算运行时间(4进程读数据,根据自己的CPU核心数和性能自主决定用几个进程(几个核心))
    timer = d2l.Timer()
    for X,y in train_iter:
        continue
    print(f"{timer.stop():.2f}秒")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17

    运行结果:
    3.09秒

    定义load_data_fashion_mnist函数

    把上面的所有功能都整合到load_data_fashion_mnist函数中

    def load_data_fashion_mnist(batch_size, resize=None): #@save
        """
        下载Fashion-MNIST数据集,然后将其加载到内存中
        参数:
            batch_size: 批次大小,用于小批量训练
            resize: 可选参数,指定图像调整的大小
        返回:
            包含训练数据迭代器和测试数据集的元组
        """
        # 创建数据变换列表,将图像转换为Tensor格式
        trans = [transforms.ToTensor()]
        
        # 如果提供了resize参数,将图像调整大小添加到变换列表
        if resize:
            trans.insert(0, transforms.Resize(resize))
            
        # 将变换列表组合成一个组合变换
        trans = transforms.Compose(trans)
        
        # 创建FashionMNIST数据集的训练集实例
        mnist_train = torchvision.datasets.FashionMNIST(
            root="../data",              # 数据集存放的根目录
            train=True,                  # 表示加载训练集
            transform=trans,             # 数据变换,包括调整大小和转换为Tensor
            download=True                # 是否下载数据集(如果尚未下载的话)
        )
        
        # 创建FashionMNIST数据集的测试集实例
        mnist_test = torchvision.datasets.FashionMNIST(
            root="../data",              # 数据集存放的根目录
            train=False,                 # 表示加载测试集
            transform=trans,             # 数据变换,包括调整大小和转换为Tensor
            download=True                # 是否下载数据集(如果尚未下载的话)
        )
        
        # 创建训练数据迭代器,并指定批次大小、是否打乱顺序和数据加载进程数量
        train_data = data.DataLoader(
            mnist_train,                 # 使用的训练数据集实例
            batch_size,                  # 每个批次的样本数量
            shuffle=True,                # 是否在每个epoch前打乱数据顺序
            num_workers=get_dataloader_workers()  # 数据加载进程数量
        )
        
        # 创建测试数据迭代器,并指定批次大小、不打乱顺序和数据加载进程数量
        test_data = data.DataLoader(
            mnist_test,                  # 使用的测试数据集实例
            batch_size,                  # 每个批次的样本数量
            shuffle=False,               # 不打乱数据顺序
            num_workers=get_dataloader_workers()  # 数据加载进程数量
        )
        
        # 返回训练数据迭代器和测试数据迭代器的元组
        return train_data, test_data
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
  • 相关阅读:
    车内静谧性超越埃尔法?走进腾势D9身价上亿的NVH实验室
    代码整洁之道-读书笔记之单元测试
    串口调试助手和网络调试助手使用总结
    【C/C++笔试练习】OSI分层模型、源端口和目的端口、网段地址、SNMP、状态码、tcp报文、域名解析、HTTP协议、计算机网络、美国节日、分解因数
    ⑥、学习HTML 样式- CSS
    算法中的变形金刚——单纯形算法学习笔记
    通讯网关软件020——利用CommGate X2Mysql实现Modbus TCP数据转储Mysql
    初识JSBridge:从原理到使用(android、ios、js三端互通的工具)
    罗丹明PEG活性酯 RB-PEG-NHS,罗丹明聚乙二醇活性酯,Rhodamine-PEG-NHS
    问题求解:总计600人,每次刀一个奇数位的人,最后剩下谁的概率最高 暴力求解法
  • 原文地址:https://blog.csdn.net/qq_30204431/article/details/133429373