• B站刘二大人-数据集及数据加载 Lecture 8



    系列文章:


    文章目录


    y_pred = model(x_data)是 使用所有的数据
    想进行批处理,了解几个概念
    在这里插入图片描述
    import torch
    from torch.utils.data import Dataset #Dataset抽象子类,需要继承
    from torch.utils.data import DataLoader #DataLoade用来加载数据

    在这里插入图片描述
    def getitem(self, index):

    def len(self): 返回数据集长度
    dataset = DiabetesDataset() 构造DiabetesDataset对象
    train_loader = DataLoader(dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2) 初始化参数

    import numpy as np
    import torch
    import matplotlib.pyplot as plt
    # Dataset是抽象类
    from  torch.utils.data import  Dataset
    # DataLoader 是抽象类
    from  torch.utils.data import DataLoader
    
    class LogisticRegressionModel(torch.nn.Module):
        def __init__(self):
            super(LogisticRegressionModel, self).__init__()
            # 输入维度8输出维度6
            self.lay1 = torch.nn.Linear(8,6)
            self.lay2 = torch.nn.Linear(6,4)
            self.lay3 = torch.nn.Linear(4,1)
            self.sigmod = torch.nn.Sigmoid()
    
        def forward(self,x):
            x = self.sigmod(self.lay1(x))
            x = self.sigmod(self.lay2(x))
            x = self.sigmod(self.lay3(x))
            return  x
    
    class DiabetesDataset(Dataset):
        def __init__(self, filepath):
            xy = np.loadtxt(filepath,  delimiter=',', dtype=np.float32)
            self.len = xy.shape[0]
            self.x_data = torch.from_numpy(xy[:,:-1])
            self.y_data = torch.from_numpy(xy[:, [-1]])
        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]
    
        def __len__(self):
            return  self.len
    
    
    dataset = DiabetesDataset("./datasets/diabetes.csv.gz")
    train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
    model = LogisticRegressionModel()
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
    epoch_list = []
    loss_list = []
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
    #         1-加载数据
            inputs, label = data
    #         2-forward
            y_pred = model(inputs)
            loss = criterion(y_pred, label)
            epoch_list.append(epoch)
            loss_list.append(loss.item())
            optimizer.zero_grad()
            # 3-反向传播
            loss.backward()
            # Update
            optimizer.step()
    
    plt.plot(epoch_list, loss_list)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63

    在这里插入图片描述
    MNIST数据集导入

    import torch
    from  torch.utils.data import  DataLoader,Dataset
    from torchvision import datasets,transforms
    
    train_dataset = datasets.MNIST(root='./datasets/mnist', train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)
    test_dataset = datasets.MNIST(root='./datasets/mnist', train=False,
                                  transform=transforms.ToTensor(),
                                  download=True)
    train_loader = DataLoader(dataset=datasets, batch_size=32,
                              shuffle=True)
    
    test_loader = DataLoader(dataset=test_dataset, batch_size=32,
                             shuffle=False)
    for batch_idx, (inouts, target) in enumerate(test_loader):
        pass
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
  • 相关阅读:
    第20节-PhotoShop基础课程-橡皮檫工具
    基于python下django框架 实现校园教室实验室预约系统详细设计
    k8s部署 多master节点负载均衡以及集群高可用
    网络面试-ox07http中的keep-alive以及长/短连接
    java-php-net-python-校园后勤计算机毕业设计程序
    vue3 - diff算法之快速diff算法
    C Primer Plus(6) 中文版 第2章 C语言概述 2.7 调试程序
    java——注释与空行
    Spring学习笔记(1)
    网络安全——(黑客)自学
  • 原文地址:https://blog.csdn.net/weixin_42382758/article/details/125594686