• Pytorch中DataLoader的使用方法


    在Pytorch中,torch.utils.data中的Dataset与DataLoader是处理数据集的两个函数,用来处理加载数据集。通常情况下,使用的关键在于构建dataset类。

    一:dataset类构建。

    在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。

    1. class dataset:
    2. def __init__(self,...):
    3. ...
    4. def __len__(self,...):
    5. return n
    6. def __getitem__(self,item):
    7. return data[item]

    正常情况下,该数据集是要继承Pytorch中Dataset类的,但实际操作中,即使不继承,数据集类构建后仍可以用Dataloader()加载的。

    在dataset类中,__len__(self)返回数据集中数据个数,__getitem__(self,item)表示每次返回第item条数据。

    二:DataLoader使用

    在构建dataset类后,即可使用DataLoader加载。DataLoader中常用参数如下:

    1.dataset:需要载入的数据集,如前面构造的dataset类。

    2.batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个batch进行训练。

    3.shuffle:是否在打乱数据集样本顺序。True为打乱,False反之。

    4.drop_last:是否舍去最后一个batch的数据(很多情况下数据总数N与batch size不整除,导致最后一个batch不为batch size)。True为舍去,False反之。

    三:举例

    兔兔以指标为1,数据个数为100的数据为例。

    1. import torch
    2. from torch.utils.data import DataLoader
    3. class dataset:
    4. def __init__(self):
    5. self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
    6. self.y=(torch.sin(self.x)+1)/2
    7. def __len__(self):
    8. return 100
    9. def __getitem__(self, item):
    10. return self.x[item],self.y[item]
    11. data=DataLoader(dataset(),batch_size=10,shuffle=True)
    12. for batch in data:
    13. print(batch)

    当然,利用这个数据集可以进行简单的神经网络训练。

    1. from torch import nn
    2. data=DataLoader(dataset(),batch_size=10,shuffle=True)
    3. bp=nn.Sequential(nn.Linear(1,5),
    4. nn.Sigmoid(),
    5. nn.Linear(5,1),
    6. nn.Sigmoid())
    7. optim=torch.optim.Adam(params=bp.parameters())
    8. Loss=nn.MSELoss()
    9. for epoch in range(10):
    10. print('the {} epoch'.format(epoch))
    11. for batch in data:
    12. yp=bp(batch[0])
    13. loss=Loss(yp,batch[1])
    14. optim.zero_grad()
    15. loss.backward()
    16. optim.step()
  • 相关阅读:
    Bootstrap-- 逻辑运算符
    FT 在图像处理中的应用
    Git入门
    FreeBASIC通过Delphi7 DLL调用MS SOAP使用VB6 Webservice
    STM32MP157_TF-A源码编译报错
    122. 买卖股票的最佳时机 II
    【STM32】使用RTE ,从 0 开始创建一个 (keil) ARM MDK工程(纯keil,标准库,以STM32F103C8T6为例)
    WWDC 2024 回顾:Apple Intelligence 的发布与解析
    KVB交易平台:国内三大交易所(上海、深圳、北京)的概要与分析
    【网络编程套接字】基于TCP协议的网络程序
  • 原文地址:https://blog.csdn.net/weixin_60737527/article/details/126754254