• 利用torch.nn实现logistic回归在人工构造的数据集上进行训练和测试


    利用torch.nn实现logistic回归在人工构造的数据集上进行训练和测试:

    1)(2)(3)同手动实现

    4)构建内置迭代器

    1. #构建迭代器  
    2. lr = 0.03  
    3. batch_size = 10  
    4. 将训练数据的特征和标签组合  
    5. dataset = Data.TensorDataset(x, y)  
    6.   
    7.  dataset 放入 DataLoader  
    8. data_iter = Data.DataLoader(  
    9.     dataset=dataset, # torch TensorDataset format  
    10.     batch_size=batch_size, # mini batch size  
    11.     shuffle=True, 是否打乱数据 (训练集一般需要进行打乱)  
    12.     num_workers=0, 多线程来读数据,注意在Windows下需要设置为0  

    5)构建Logistic模型

    1. #构建Logistic模型  
    2. class LogisticNet(torch.nn.Module):  
    3.     def __init__(self, n_feature):  
    4.         super(LogisticNet, self).__init__()  
    5.         self.linear = torch.nn.Linear(n_feature, 1)  
    6.         self.sigmoid = torch.nn.Sigmoid()  
    7. # forward 定义前向传播  
    8.     def forward(self, x):  
    9.         x = self.linear(x)  
    10.         x = self.sigmoid(x)  
    11.         return x  
    12. net = LogisticNet(2)  

    6)参数初始化和定义损失函数及优化方法

    1. #参数初始化  
    2. init.normal_(net.linear.weight, mean=0, std=1.0)  
    3. init.constant_(net.linear.bias, val=0)  也可以直接修改biasdata:net[0].bias.data.fill_(0)  
    4. #损失函数和优化方法  
    5. loss = torch.nn.BCELoss()  
    6. optimizer = optim.SGD(net.parameters(), lr=0.03) #梯度下降的学习率指定为0.03  

    7)开始训练并输出每轮最后一批次训练集的平均损失

    1. #开始训练并计算每轮损失  
    2. num_epochs = 20  
    3. for epoch in range(1, num_epochs + 1):  
    4.     for X, Y in data_iter:  
    5.         output = net(X)  
    6.         l = loss(output, Y.view(-1, 1))  
    7.         optimizer.zero_grad() 梯度清零,等价于net.zero_grad()  
    8.         l.backward()  
    9.         optimizer.step()  
    10.     print('epoch %d, loss: %f' % (epoch, l.item()))#仅最后一批训练集的损失的均值  
    11. 训练集上的正确率  
    12.     allTrain = 0  
    13.     rightTrain = 0  
    14.     for train_x, train_y in data_iter:  
    15.         allTrain += len(train_y)  
    16.         train_out = net(train_x)  
    17.         mask = train_out.ge(0.5).float()  
    18.         correct = (mask.view(-1, 1) == train_y.view(-1, 1)).sum()  
    19.         rightTrain += correct.float().sum()  
    20.     print('train accuracy: %f' % (rightTrain/allTrain))
  • 相关阅读:
    Redis 线程模型
    【图解大数据技术】流式计算:Spark Streaming、Flink
    【手写一个SpringBoot简易版框架】
    使用pro-components遇到的问题
    QT:QML中使用Loader加载界面
    来自BAT的一份Java高级开发岗面试指南
    大语言模型开发各个阶段的评估方法(未完)
    下厨房网站月度最佳栏目菜谱数据获取及分析PLus
    java计算机毕业设计旅游信息分享网站源码+mysql数据库+系统+lw文档+部署
    【愚公系列】华为云系列之DevCloud+ECS+MySQL搭建超级冷笑话网站【开发者专属集市】
  • 原文地址:https://blog.csdn.net/ccyyll1/article/details/126020626