• 深度学习(PyTorch)——Dataset&DataLoader 加载数据集


    B站up主“刘二大人”视频 笔记

    说在前面:
    本节内容,主要是把数据集写成了一个类,这个类要继承Dataset类,有点像DIY一个数据集的感觉,只有自定义了之后,才能实例化,然后把之前直接在文件夹中读取数据的方式进行了修改;

    后面加载数据的DataLoader(注意L大写),直接可以调用对数据集类做了实例化的对象,即把他当做一个参数,传入DataLoader当中;

    详细过程:
    本课程的主要任务是通过将原本简单的标量输入,升级为向量输入,构建线性传播模型:
    在导入数据阶段就有很大不同:
    数据集类里面有三个函数,这三个函数较为固定,分别自己的作用;
    继承Dataset后我们必须实现三个函数:
    __init__()是初始化函数,之后只要提供数据集路径,就可以进行数据的加载,也就是说,传入init的参数,只要有一个文件路径就可以了;
    getitem__()通过索引找到某个样本;
    __len__()返回数据集大小;

    程序如下:

    1. import torch
    2. import numpy as np
    3. from torch.utils.data import Dataset # s数据工具提供了2个类,一个是Dataset,另一个是DataLoader
    4. from torch.utils.data import DataLoader # 可以帮助我们加载数据
    5. # Dataset是抽象类,不能实例化,只能被继承。DataLoader可以实例化
    6. class DiabetesDataset(Dataset): # 继承Dataset
    7. def __init__(self, filepath):
    8. xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
    9. self.len = xy.shape[0] # xy是一个n*9的矩阵,xy.shape=(N,9)的元组,xy.shape[0]=N
    10. self.x_data = torch.from_numpy(xy[:, :-1]) # 所有行,从第一列开始,最后一列不要 xdata与ydata的结果均为tensor
    11. self.y_data = torch.from_numpy(xy[:, [-1]]) # 所有行,只要最后一列,-1加了中括号是为了拿出的数据是矩阵
    12. def __getitem__(self, index): # 可以支持下标索引寻找数据
    13. return self.x_data[index], self.y_data[index]
    14. def __len__(self): # 返回数据集条数
    15. return self.len
    16. dataset = DiabetesDataset('diabetes.csv.gz')
    17. train_loader = DataLoader(
    18. dataset=dataset, # 传递数据集
    19. batch_size=32, # 容量是多少
    20. shuffle=True, # 是否打乱
    21. num_workers=2) # 进程为2,是否并行
    22. class Model(torch.nn.Module):
    23. def __init__(self):
    24. super(Model, self).__init__()
    25. self.linear1 = torch.nn.Linear(8, 6)
    26. self.linear2 = torch.nn.Linear(6, 4)
    27. self.linear3 = torch.nn.Linear(4, 1)
    28. self.sigmoid = torch.nn.Sigmoid() # 给模型添加一个非线性变换
    29. '''self.activate = torch.nn.ReLU()'''
    30. def forward(self, x):
    31. x = self.sigmoid(self.linear1(x))
    32. x = self.sigmoid(self.linear2(x))
    33. x = self.sigmoid(self.linear3(x))
    34. return x
    35. model = Model()
    36. criterion = torch.nn.BCELoss(size_average=True) # 构造损失函数
    37. optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 构造优化器 lr为学习率
    38. if __name__ == '__main__':
    39. for epoch in range(1000):
    40. for batch_idx, data in enumerate(train_loader, 0): # enumerate获得当前迭代的次数
    41. # for batch_idx, (inputs, labels) in enumerate(train_loader, 0): # enumerate获得当前迭代的次数
    42. # 1.数据的准备
    43. inputs, labels = data # 从data里面拿出x和y,他们都是tensor
    44. # 2.前馈
    45. y_pre = model(inputs) # 在前馈算y_hat
    46. loss = criterion(y_pre, labels) # 计算损失
    47. print(epoch, batch_idx, loss.item())
    48. # 3.反馈
    49. optimizer.zero_grad() # 把所有权重的梯度归零
    50. loss.backward() # 反馈
    51. # 4.更新
    52. optimizer.step() # 更新

    运行结果如下:

    视频截图如下: 

     epoch:当所有的训练本都进行了一次前馈和反馈,即完成一次epoch

     dataset是抽象类不能实例化,只能被其他类继承;dataloader可以实例化

     

     

     

     

     

     

     

    epoch=100表示把所有数据都跑100遍 

     

     

  • 相关阅读:
    C高级day1
    PythonStudy5
    笙默考试管理系统-SMExamination.Model.Notice展示
    适合计算机编程开发的笔记本电脑推荐
    【学习】软件压力测试对软件产品的作用
    2 errors and 0 warnings potentially fixable with the `--fix` option.(Vue后台管理系统)
    line-height用了这么久,你真的了解他么
    月薪11K,国企小哥抛弃“铁饭碗”转行测试,亲身经历告诉你选高薪or稳定~
    一致性 Hash 算法
    java---jar详解
  • 原文地址:https://blog.csdn.net/qq_42233059/article/details/126559003