• 【深度学习 Pytorch笔记 B站刘二大人 数据集加载 Dataset&DataLoader 模块实现与源码解读(7/10)】


    数据集加载 Dataset&DataLoader 模块实现与源码详解 深度学习 Pytorch笔记 B站刘二大人 (7/10)

    模块介绍

    在本节中没有关于数学原理的相关介绍,使用的数据集和类型仍然是(6、10)的相关内容。

    在这里主要是介绍dataset ,dataloader两个类以及mini-batch方法

    梯度下降方法:
    1.全部数据都使用,batch,最大化使用向量计算优势,提升计算速度
    2.随机梯度下降,只用一个样本,会得到较好的随机性,克服鞍点的问题。

    Mini-batch :将1,2进行结合,进行两层循环嵌套,在每个epoch中执行一次mini-batch。
    每个epoch中进行一次完整的forward和backward,iteration = Batch / miniBatch
    在这里插入图片描述

    Shuffle主要用于将数据集随机打乱,loader将打乱数据集根据size相组合

    在这里插入图片描述

    代码解读与模块实现

    注意点:dataset为抽象类,不能实例化,但是dataloader可以进行实例化操作

    魔法方法,python文件的内置方法,重写后的num workers参数决定进行多线程读入的线程个数

    在这里插入图片描述

    有时会出现代码报错的情况,且仅在Windows下会报错

    原因是数据类型在Windows系统下应该用spwan代替fork

    解决办法:
    在这里插入图片描述

    重写dataloader类:

    在这里插入图片描述

    训练过程

    在这里插入图片描述

    整体代码

    ''' coding:utf-8 '''
    
    """
    作者:shiyi
    日期:年 09月 03日
    pytorch加载数据集,重写Dataset函数导入数据
    """
    
    import numpy as np
    import torch
    from torch.utils.data import Dataset    # 引入抽象类dataset
    from torch.utils.data import DataLoader    # 帮助将数据导入pytorch完成数据类型的转换
    
    
    class DiabetesDataset(Dataset):         # 重写dataset类
        def __init__(self, filepath):       # 构造函数中加入路径参数
            xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)      # 设置读取数据类型为float32以满足GPU迁移的需求
            self.len = xy.shape[0]          # 设置内参length
            self.x_data = torch.from_numpy(xy[:, :-1])      # 输入数据读取 除最后一列 的所有行数据
            self.y_data = torch.from_numpy(xy[:, [-1]])     # 标签数据读取 最后一列 的所有行数据
    
        def __getitem__(self, index):       # 设置函数接口,可以输出训练数据的索引
            return self.x_data[index], self.y_data[index]
    
        def __len__(self):                  # 设置接口可以输出数据length
            return self.len
    
    
    dataset = DiabetesDataset('D:\\pytorch_prac\\dataset\\diabetes.csv.gz')     # 读取数据
    train_loader = DataLoader(dataset=dataset,              # 训练数据
                              batch_size=32,                # 设置batch 一次训练数据次数
                              shuffle=True,                 # 是否打乱顺序 是
                              num_workers=2)                # 双线程进行
    
    
    class Model(torch.nn.Module):                   # 构建深度学习模型
        def __init__(self):                         # 注释已经在(6/10)中详细写过,不在赘述
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()       #self.activate = torch.nn.ReLU()
    
        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.sigmoid(self.linear2(x))
            x = self.sigmoid(self.linear3(x))
            return x
    
    
    model = Model()     # 实例化
    
    criterion = torch.nn.BCELoss(size_average=True)
    optimizer = torch.optim.ASGD(model.parameters(), lr=0.1)
    
    if __name__ == '__main__':
        for epoch in range(100):            # 注释已经在(6/10)中详细写过,不在赘述
            for i, data in enumerate(train_loader, 0):
                # 1.prepare data
                inputs, labels = data
                # 2.Forward
                y_pred = model(inputs)
                loss = criterion(y_pred, labels)
                print(epoch, i, loss.item())
                # Backward
                optimizer.zero_grad()
                loss.backward()
                # Update
                optimizer.step()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69

    运行结果

    99 19 0.5463573336601257
    99 20 0.552997887134552
    99 21 0.5635378360748291
    99 22 0.6222219467163086
    99 23 0.5633459687232971
    
    Process finished with exit code 0
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
  • 相关阅读:
    Sqlserver 多行合并为一行
    JAVA毕业设计考勤系统设计计算机源码+lw文档+系统+调试部署+数据库
    Linux账号密码安全运维
    快速简单制作Mac系统ISO格式镜像之macOS Sonoma
    靠!我被项目经理和同事嘲笑了,因为不会远程debug调试...
    Android Termux安装MySQL,并使用cpolar实现公网安全远程连接[内网穿透]
    Java设计模式-中介者模式
    【MATLAB源码-第56期】基于WOA白鲸优化算法和PSO粒子群优化算法的三维路径规划对比。
    git - rebase 使用
    2023工博会,正运动开放式激光振镜运动控制器应用预览(三)
  • 原文地址:https://blog.csdn.net/qq_43649786/article/details/126880369