• 使用Dataset 和DataLoader 加载数据集


    一、加载数据过程

    PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader 类。 它表示可在数据集上迭代的 Python,并支持

    这些选项由 DataLoader 的构造函数参数配置,构造函数的签名如下:

    如下如显示了dataLoader的过程,shuffle将Dataset里的数据打乱,batch_size=2

    二、模型建立流程

    1、准备数据集(Dataset和DataLoader)2、继承Module类设计自己的模型

    3、使用PyTorch APi 构造损失函数和优化器  4、采用前向传播、返向回馈、更新 反复训练。

    三、代码实现

    import torch.nn
    import numpy as np
    from torch.utils.data import Dataset, DataLoader

    class DiabetesDataset(Dataset):
       
    def __init__(self, filepath):
            xy = np.loadtxt(filepath,
    delimiter=',', dtype=np.float32)
           
    self.len = xy.shape[0]
           
    self.x_data = torch.from_numpy(xy[:, :-1])
           
    self.y_data = torch.from_numpy(xy[:, [-1]])

       
    def __getitem__(self, index):
           
    return self.x_data[index], self.y_data[index]

       
    def __len__(self):
           
    return self.len


    dataset = DiabetesDataset(
    'diabetes.csv.gz')
    train_loader = DataLoader(
    dataset=dataset, batch_size=64, shuffle=True, num_workers=2)


    # 继承类Module,自动会实现反向计算图
    class Model(torch.nn.Module):
       
    # 构造方法
       
    def __init__(self):
           
    super(Model, self).__init__()
           
    self.linear1 = torch.nn.Linear(8, 6)
           
    self.linear2 = torch.nn.Linear(6, 4)
           
    self.linear3 = torch.nn.Linear(4, 1)
           
    self.sigmoid = torch.nn.Sigmoid()

       
    def forward(self, x):
            x =
    self.sigmoid(self.linear1(x))
            x =
    self.sigmoid(self.linear2(x))
            x =
    self.sigmoid(self.linear3(x))
           
    return x


    model = Model()

    criterion = torch.nn.BCELoss(
    size_average=True)
    optimizer = torch.optim.SGD(model.parameters(),
    lr=0.1)

    if __name__=='__main__':
       
    for epoch in range(100):
           
    for i, data in enumerate(train_loader, 0):
               
    #1.prepare data
               
    inputs, labels = data
               
    #2.Forward
                
    y_pred = model(inputs)
                loss = criterion(y_pred, labels)
               
    print(epoch, loss.item())
               
    #3.Backward
               
    optimizer.zero_grad()
                loss.backward()
               
    #4.Update
               
    optimizer.step()

    四、运行结果

  • 相关阅读:
    核心实验16_端口镜像_ENSP
    aliyun Rest ful api V3版本身份验证构造
    Js中常见的数据结构(Data Structure)
    第六章 数学(二)
    Maven 打包方式探究
    GIC/ITS代码分析(13)LPI中断虚拟化之KVM中ITS设备的模拟
    xPortPendSVHandler任务切换流程
    pyenv fails with : ModuleNotFoundError: No module named ‘_ctypes‘ error
    windows mysql解压缩版安装指南
    嵌入式学习笔记
  • 原文地址:https://blog.csdn.net/axiaoquan/article/details/127649061