• 【花书笔记|PyTorch版】手动学深度学习303: 线性神经模型:代码部分(下)


    3.5 图像分类数据集

    这章主要是讲对数据的一些处理和准备

    数据集:使用类似但更复杂的Fashion-MNIST数据集

    在这里插入图片描述

    %matplotlib inline
    import torch
    
    # torchvision:pytorch视觉实现的一个库
    import torchvision
    #from torchvision import datasets
    from torch.utils import data
    
    # transforms对数据进行操作的模组
    from torchvision import transforms
    from d2l import torch as d2l
    
    # 这个函数也是画图的函数
    d2l.use_svg_display()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    ① 将数据从torch库中下载下来,并转变成pytorch的数据类型torchvision.datasets

    torchvision:https://blog.csdn.net/wohu1104/article/details/107743290

    知识点1; 我们使用是配套的

    • 如果导入的是import torchvision,那么使用你面的函数就是torchvision.datasets.XXX;
    • 如果导入的是from torchvision import datasets,那么使用你面的函数就是datasets.XXX;

    同理torchvision下的主要函数transform

    # 通过ToTensor实例将图像数据从PIL类型变换成float32格式,
    # 并除以255使得所有像素的数值均在0到1之间
    trans = transforms.ToTensor()
    
    # 准备训练集 和测试机
    # 训练集下载的地方;是否是训练集;下载时候转变成32位浮点数格式;是否下载
    mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 【观察数据】

    Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

    # 看看训练集 测试集长度
    len(mnist_train), len(mnist_test)
    
    • 1
    • 2
    (60000, 10000)
    
    • 1

    个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1。 为了简洁起见,本书将高度h像素、宽度w像素图像的形状记为或(h,w)

    mnist_train[0][0].shape
    
    • 1
    torch.Size([1, 28, 28])
    
    • 1
    • mnist_train[0][0]第一个是取第几张照片,范围是0~59999;60000会报错,因为一共只有60000张图
    • 第二个取值0图片数据1标签数据

    下面是第1张图属于第9类;第60000张图属于第5

    mnist_train[0][1]
    
    • 1
    9
    
    • 1
    mnist_train[59999][1]
    
    • 1
    5
    
    • 1

    ② 处理分类,用get_fashion_mnist_labels将分类数字0-9与具体名称一一对应

    Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。 以下函数用于在数字标签索引及其文本名称之间进行转换。

    def get_fashion_mnist_labels(labels):  #@save
        """返回Fashion-MNIST数据集的文本标签"""
        text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                       'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
        return [text_labels[int(i)] for i in labels]
    
    • 1
    • 2
    • 3
    • 4
    • 5

    ③ 名称有了,用show_images 来呈现图像

    参数:图片 展示成几行几列 标题默认没有 规模(尺寸大小)

    def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    
        # 设置图片尺寸:scale就是每张图的大小(基数),这里是1.5;改大图像就变大
        figsize = (num_cols * scale, num_rows * scale)
        
         # 这里的 _ 表示忽略不使用的变量、即fig;
         # d2l.plt.subplots()把多张图拼成一张,其中figsize把上面尺寸传下来
        _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
        
        #把一张图的数据拉直(成):在用plt.subplots画多个子图中,ax = ax.flatten()将ax由n*m的Axes组展平成1*nm的Axes组
        # 这一步就是 不用在二维数组来定位第几个图了 直接第几个图,会自动落位?
        axes = axes.flatten()
        #print("axes",axes[1][0])
        
        for i, (ax, img) in enumerate(zip(axes, imgs)):
            
            
            # i:第几个图
            # imgs:是一个图的数据,是个28*28的二维数组
    
            # 判断传入的图片是否为张量
            if torch.is_tensor(img):
                # 图片张量
                ax.imshow(img.numpy())
            else:
                # PIL图片
                ax.imshow(img)
                
                # 设取消横纵坐标上的刻度(横、纵轴均为28)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            if titles:
                ax.set_title(titles[i])
        return axes
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 把多张图拼成一张图的函数:fig, ax = plt.subplots():https://www.cnblogs.com/komean/p/10670619.html

      fig, ax = plt.subplots(1,3): 其中参数1和3分别代表子图的行数和列数,一共有 1x3 个子图像。函数返回一个figure图像和子图ax的array列表。

      fig, ax = plt.subplots(1,3,1): 最后一个参数1代表第一个子图。
      如果想要设置子图的宽度和高度可以在函数内加入figsize值

      fig, ax = plt.subplots(1,3,figsize=(15,7)): 这样就会有1行3个15x7大小的子图。【本文用法】

    • axes = axes.flatten():https://blog.csdn.net/weixin_38314865/article/details/84785141

    • zip():https://blog.csdn.net/lanmy_dl/article/details/124216431

    • enumerate()遍历: https://www.runoob.com/python/python-func-enumerate.html

    • 结合使用;https://blog.csdn.net/weixin_43408110/article/details/87731547

    ④ 来设计上一步的X:imagsy:titles

    # 取出X y 用next(iter)搞了第一组数据,数据一组18个
    X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
    
    #我们一个`X:torch.Size([18, 1, 28, 28])`传进来18张照片 每张[1*28*28]
    X.shape
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    torch.Size([18, 1, 28, 28])
    
    • 1
     # 不要色彩通道了 直接这组张数 和 每张照片的尺寸;一排9张 一o共2排
    show_images(X.reshape(18,28, 28), 2,9, titles=get_fashion_mnist_labels(y));
    
    • 1
    • 2

    在这里插入图片描述

    3.5.2 读取小批量

    # 准备训练数据
    batch_size = 256
    
    def get_dataloader_workers():  #@save
        """使用4个进程来读取数据"""
        return 4
    
    # batch_size一组256个,shuffle随机取,几个进程做
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                                 num_workers=get_dataloader_workers())
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    # 测试了一下训练的时间
    timer = d2l.Timer()
    for X, y in train_iter:
        continue
    f'{
         timer.stop():.2f} sec'
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    '5.81 sec'
    
    • 1

    3.5.3. 整合所有组件

    # resize 就是我们输入照片是28*28,如果想把尺寸变大 用到resize
    
    def load_data_fashion_mnist(batch_size, resize=None):  #@save
        
    • 1
    • 2
    • 3
  • 相关阅读:
    机器学习 —— 计算评估指标
    安装软件显示“为了对电脑进行保护,已阻止此应用”——已解决
    【Eureka】【源码+图解】【八】Eureka客户端的服务获取
    【沧元图】玉阳宫主是正是邪,和面具人有勾结吗?现在已有答案了
    基于ssm的养老智慧服务平台毕业设计源码071526
    Servlet | HttpServlet源码分析、web站点的欢迎页面
    Final Cut Pro使用教程
    sql注入
    爬取某网站计算机类图书
    数据科学技术与应用——第2章 多维数据结构与运算
  • 原文地址:https://blog.csdn.net/wistonty11/article/details/127778372