• 神经网络和深度学习-加载数据集DataLoader


    加载数据集DataLoader

    Dataloader的概念

    dataloader的主要目标是拿出Mini-Batch这一组数据来进行训练

    在处理多维特征输入这一文章中,使用diabetes这一数据集,在训练时我们使用的是所有的输入x,在梯度计算采用的是随机梯度下降(SDG),每次选用一个样本来进行梯度计算,但存在缺点,优化时间过长

    而在Mini-Batch中我们选择小批量中的所有样本,可以最大化的利用向量的优势,来提升计算速度

    在使用Mini-Batch我们要了解三个概念

    • Epoch

    • Batch-Size

    • Iterations

    首先我们来看一下Epoch,我们采用Mini-Batch之后要使用一个嵌套循环,内循环是每一次迭代都执行一个Mini-Batch,这两个循环相当于把所有的Mini-Batch都跑了一遍

    在这里插入图片描述

    Epoch的定义就是:所有训练样本都进行一次前向传播和反向传播的过程

    Batch-Size的定义是:进行一次前馈和反馈的训练样本数量

    Iterations的定义是:所有的样本/Batch-Size

    Dataloader的作用

    我们要做小批量的训练时,要确定一些重要的参数

    • batch-size

    • shuffle:打乱顺序,为了提高数据样本的随机性可以选择对数据集进行shuffle

    • num_workers :并行操作的数量

    • [i]:支持索引

    • len:长度

    在这里插入图片描述

    定义Dataset和DataLoader

    我们来看一下代码中是如何定义dataset的,在torch.utils.data工具包中包含了这两个类

    其中dataset是一个抽象的类,不能实例化,只能被其他的子类继承,我们想要使用的时候必须定义一个自己 的类来继承使用

    dataloader是用来帮助加载数据的,我们可以实例化一个dataloader

    例如下面自定义一个DiabetesDataset的类

    在这里插入图片描述

    getitem这个方法是一个模板方法,是为了实例化这个对象之后能够支持下标操作,通过索引来取出数据

    len这个方法同样是模板方法,为了返回数据集中的数据条数

    接下来就可以用自定义的DiabetesDataset类来实例化dataset对象

    我们在构造数据集的时候一般有两种选择

    • 把所有数据在init中加载进来,放入内存中,再用getitem根据索引传出数据,适用于数据集本身的容量不大

    • 类似于图像、语音这种非结构的大数据集,不能一次性加载到内存中时,定义一个列表,数据集里面得每一条数据的文件名放入相应的列表中

    📌我们在windows中使用num_workers进行训练会报错,原因是在windows下和Linux下的进程库是不一样的。所以用spawn替代了fork,所以其中处理的方式不同,会出现RuntimeError

    📌解决方法:将要训练的代码train_loader进行封装起来(if语句或者是函数中)

    在这里插入图片描述

    我们在代码中进行改动

    在这里插入图片描述

    数据集的实现

    在构造函数中我们需要一个filepath:描述文件来自什么地方,其次需要通过self.len来获取数据集的长度

    在这里插入图片描述

    DataLoader的使用

    使用enumerate可以获得当前迭代的次数,train_loader中拿出来的元组(x,y)放入data中,所以在训练之前把inputs(x_data)和labels(y_data)从data中取出,此时这两个数据都是Tensor。

    也可以一开始就在for循环中使用i,(x,y),就可以省去下面那句

    在这里插入图片描述

    完整代码

    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader
    
    
    # prepare dataset
    class DiabetesDataset(Dataset):
        def __init__(self, filepath):
            xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
            self.len = xy.shape[0]  # shape(多少行,多少列)
            self.x_data = torch.from_numpy(xy[:, :-1])
            self.y_data = torch.from_numpy(xy[:, [-1]])
    
        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]
    
        def __len__(self):
            return self.len
    
    
    dataset = DiabetesDataset('diabetes.csv.gz')
    train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)  # num_workers 多线程
    
    
    # design model using class
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()
    
        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.sigmoid(self.linear2(x))
            x = self.sigmoid(self.linear3(x))
            return x
    
    
    model = Model()
    
    
    # construct loss and optimizer
    criterion = torch.nn.BCELoss(size_average=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    
    # training cycle forward, backward, update
    if __name__ == '__main__':
        for epoch in range(100):
            for i, data in enumerate(train_loader, 0):  # train_loader 是先shuffle后mini_batch
                # 1. prepare data
                x_data, y_data = data
                # 2. Forward
                y_pred = model(x_data)
                loss = criterion(y_pred, y_data)
                print(epoch, i, loss.item())
                # 3. Backward
                optimizer.zero_grad()
                loss.backward()
                # 4. Update
                optimizer.step()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
  • 相关阅读:
    7. 微服务之Docker自动化部署
    几款常用database的性能对比
    注册公司资本认缴和实缴有何区别?
    2121. 相同元素的间隔之和-哈希表法
    Centos7 部署 Stable Diffusion
    Golang 汇编asm语言基础学习
    便携烙铁开源系统IronOS,支持多款便携DC, QC, PD供电烙铁,支持所有智能烙铁标准功能
    golang优化命令执行
    Juniper防火墙SSG-140 session 过高问题
    6、项目第六阶段——用户名登录显示和注册验证码
  • 原文地址:https://blog.csdn.net/weixin_55500281/article/details/128090008