• 【pytorch笔记】第五篇 torchvision,Dataloader,nn.Module的使用


    1 torchvision数据集介绍

    ① torchvision中有很多数据集,当我们写代码时指定相应的数据集指定一些参数,它就可以自行下载。

    CIFAR-10数据集包含60000张32×32的彩色图片,一共10个类别,其中50000张训练图片,10000张测试图片。

    1.1 torchvision数据集使用

    import torchvision
    help(torchvision.datasets.CIFAR10)
    
    • 1
    • 2
    Output exceeds the size limit. Open the full output data in a text editor
    Help on class CIFAR10 in module torchvision.datasets.cifar:
    
    class CIFAR10(torchvision.datasets.vision.VisionDataset)
     |  `CIFAR10 //www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
     |  
     |  Args:
     |      root (string): Root directory of dataset where directory
     |          ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
     |      train (bool, optional): If True, creates dataset from training set, otherwise
     |          creates from test set.
     |      transform (callable, optional): A function/transform that takes in an PIL image
     |          and returns a transformed version. E.g, ``transforms.RandomCrop``
     |      target_transform (callable, optional): A function/transform that takes in the
     |          target and transforms it.
     |      download (bool, optional): If true, downloads the dataset from the internet and
     |          puts it in root directory. If dataset is already downloaded, it is not
     |          downloaded again.
     |  
     |  Method resolution order:
     |      CIFAR10
     |      torchvision.datasets.vision.VisionDataset
     |      torch.utils.data.dataset.Dataset
     |      typing.Generic
     |      builtins.object
     |  
    ...
     |  
     |  __new__(cls, *args, **kwds)
     |      Create and return a new object.  See help(type) for accurate signature.
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    1.2 查看CIFAR10数据集内容

    import torchvision
    train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True) # root为存放数据集的相对路线
    test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True) # train=True是训练集,train=False是测试集  
    
    print(test_set[0])       # 输出的3是target 
    print(test_set.classes)  # 测试数据集中有多少种
    
    img, target = test_set[0] # 分别获得图片、target
    print(img)
    print(target)
    
    print(test_set.classes[target]) # 3号target对应的种类
    img.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    Files already downloaded and verified
    Files already downloaded and verified
    (.Image.Image image mode=RGB size=32x32 at 0x1A4275AAF28>, 3)
    ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    .Image.Image image mode=RGB size=32x32 at 0x1A4275AAA58>
    3
    cat
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    2. Dataloader使用

    ① Dataset只是去告诉我们程序,我们的数据集在什么位置,数据集第一个数据给它一个索引0,它对应的是哪一个数据。

    ② Dataloader就是把数据加载到神经网络当中,Dataloader所做的事就是每次从Dataset中取数据,至于怎么取,是由Dataloader中的参数决定的。

    import torchvision
    from torch.utils.data import DataLoader
    
    # 准备的测试数据集
    test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())               
    img, target = test_data[0]
    print(img.shape)
    print(img)
    
    # batch_size=4 使得 img0, target0 = dataset[0]、img1, target1 = dataset[1]、img2, target2 = dataset[2]、img3, target3 = dataset[3],然后这四个数据作为Dataloader的一个返回      
    test_loader = DataLoader(dataset=test_data,batch_size=4,shuffle=True,num_workers=0,drop_last=False)      
    # 用for循环取出DataLoader打包好的四个数据
    for data in test_loader:
        imgs, targets = data # 每个data都是由4张图片组成,imgs.size 为 [4,3,32,32],四张32×32图片三通道,targets由四个标签组成             
        print(imgs.shape)
        print(targets)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    这里输出结果信息有点长,可以自行运行看结果

    2.1 Dataloader多轮次

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    # 准备的测试数据集
    test_data = torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor())               
    # batch_size=4 使得 img0, target0 = dataset[0]、img1, target1 = dataset[1]、img2, target2 = dataset[2]、img3, target3 = dataset[3],然后这四个数据作为Dataloader的一个返回      
    test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=True)      
    # 用for循环取出DataLoader打包好的四个数据
    writer = SummaryWriter("logs")
    for epoch in range(2):
        step = 0
        for data in test_loader:
            imgs, targets = data # 每个data都是由4张图片组成,imgs.size 为 [4,3,32,32],四张32×32图片三通道,targets由四个标签组成             
            writer.add_images("Epoch:{}".format(epoch),imgs,step)
            step = step + 1
        
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    在这里插入图片描述

    3. nn.Module模块使用

    ① nn.Module是对所有神经网络提供一个基本的类。

    ② 我们的神经网络是继承nn.Module这个类,即nn.Module为父类,nn.Module为所有神经网络提供一个模板,对其中一些我们不满意的部分进行修改。

    import torch
    from torch import nn
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()  # 继承父类的初始化
            
        def forward(self, input):          # 将forward函数进行重写
            output = input + 1
            return output
        
    myModule= MyModule()
    x = torch.tensor(1.0)  # 创建一个值为 1.0 的tensor
    output = myModule(x)
    print(output)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    tensor(2.)
    
    • 1

    3.1 super(Myclass, self)._init_()

    ① 简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的_init_()的那些东西。

    ② Myclass类继承nn.Module,super(Myclass, self).__init__()就是对继承自父类nn.Module的属性进行初始化。而且是用nn.Module的初始化方法来初始化继承的属性。

    ③ super().__init()__()来通过初始化父类属性以初始化自身继承了父类的那部分属性;这样一来,作为nn.Module的子类就无需再初始化那一部分属性了,只需初始化新加的元素。

    ③ 子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。

    3.2 forward函数

    ① 使用pytorch的时候,不需要手动调用forward函数,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数。

    ② 因为 PyTorch 中的大部分方法都继承自 torch.nn.Module,而 torch.nn.Module 的__call__(self)函数中会返回 forward()函数 的结果,因此PyTroch中的 forward()函数等于是被嵌套在了__call__(self)函数中;因此forward()函数可以直接通过类名被调用,而不用实例化对象。

    class A():
        def __call__(self, param):
            print('i can called like a function')
            print('传入参数的类型是:{}   值为: {}'.format(type(param), param))
            res = self.forward(param)
            return res
        
        def forward(self, input_):
            print('forward 函数被调用了')
            print('in  forward, 传入参数类型是:{}  值为: {}'.format( type(input_), input_))
            return input_
    
    a = A()
    input_param = a('i')
    print("对象a传入的参数是:", input_param)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    i can called like a function
    传入参数的类型是:<class 'str'>   值为: i
    forward 函数被调用了
    in  forward, 传入参数类型是:<class 'str'>  值为: i
    对象a传入的参数是: i
    
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    js-数组的方法--4个常用方法
    (4)点云数据处理学习——其它官网例子
    C++核心编程(三十三)容器(list)
    【操作系统】进程同步、进程互斥、死锁
    Clickhouse设置多磁盘存储策略
    中创生日会 | 烟火向星辰,所愿皆成真
    Java编程之道:巧妙解决Excel公式迭代计算难题
    DocuWare平台——文档管理的内容服务和工作流自动化的平台详细介绍(下)
    正则表达式匹配双引号
    深度强化学习(Deep Reinforcement Learning, DRL)阶段性学习汇总(二)
  • 原文地址:https://blog.csdn.net/qq_35793394/article/details/127787699