• 【深度学习3】线性回归与逻辑回归


     

    🍊本文详细介绍了线性回归和逻辑回归是什么,并对二者进行了对比,此外详细介绍了sigmoid函数,最后使用Pytorch来实战模拟二者

    🍊实验一使用Pytorch实现线性回归

    🍊实验二使用Pytorch实现逻辑回归

    一、Introduction

    在初学机器学习的时候,我们总会听到线性回归模型和逻辑回归模型,那么它们究竟是什么?有什么联系与区别呢?

    首先他们都是一条线,目的是让我们预测的对象尽可能多的穿过这条线

    线性回归模型是真的回归模型,但是逻辑回归模型虽然名字上带着回归,实际上是个分类模型,看看看下图就知道了。因为是完全属于不同类型的模型,因此他们的损失函数也是不一样的

    二、Principle

    2.1 线性回归模型

    损失函数

    2.2 逻辑回归模型

    这里的σ是sigmoid函数, 分类模型是需要将最后的预测结果仿射到0-1区间中,且所有的类的预测值之和为1,因此sigmoid函数最主要的作用就是将结果尽可能的仿射到0-1区间中

    损失函数

    2.3 sigmoid函数

    sigmoid型函数是指一类S型曲线函数,两端饱和。其中最出名的是logistic函数,因此很多地方直接将sigmoid函数默认为logistic函数。

    Logistic函数公式

    Logistic函数的特点是区间范围在0~1之间,而且当x等于0时,其函数值为0.5,主要用作二分类。

     Tanh函数公式

    Tanh函数的特点是区间范围在-1~1之间,而且当x等于0时,其函数值为0

    Hard-Logistic函数和Hard-Tanh函数

    三、Experiment

    伪代码

    1 Prepare dataset

    2 Design model using Class

    3 Construct loss and optimizer(Using Pytorch API)

    4 Training cycle(forward,backward,update)

    3.1 Linear  Regression

    1. import torch
    2. # 1 Prepare for the dataset
    3. x_data = torch.tensor([[1.0], [2.0], [3.0]])
    4. y_data = torch.tensor([[2.0], [4.0], [6.0]])
    5. # Define the Model
    6. class LinearModel(torch.nn.Module):
    7. def __init__(self):
    8. super(LinearModel, self).__init__() # 默认写法,一定要有
    9. self.linear = torch.nn.Linear(1, 1, bias=True)
    10. # 两个参数为输入和输出的维度,N*输入维度和N*输出维度,其模型为y=Ax+b线性模型,因此其输入输出的维度一定是一样的
    11. # bias为模型是否需要偏置b
    12. def forward(self, x):
    13. y_pred = self.linear(x) #
    14. return y_pred
    15. # 这里发现没有BP,这是因为使用Model构造出来的模型会根据你的计算图来自动进行BP
    16. model = LinearModel()
    17. # Define the criterion and optimizer
    18. criterion = torch.nn.MSELoss(size_average=False) # MSELoss是将所有的平方误差相加
    19. optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 参数:第一个为所优化的模型,第二个是学习率
    20. # Training
    21. for epoch in range(1000):
    22. y_pred = model(x_data)
    23. loss = criterion(y_pred, y_data)
    24. print(epoch, loss)
    25. optimizer.zero_grad()
    26. loss.backward()
    27. optimizer.step()
    28. # Output weight and bias
    29. print('w=', model.linear.weight.item())
    30. print('b=', model.linear.bias.item())
    31. # Test model
    32. x_test = torch.Tensor([[4.0]])
    33. y_test = model(x_test)
    34. print('y_pred=', y_test.data)

    Result

    最后预测的结果w接近2,而b接近0,这是与我们的数据集的情况相匹配的 

    3.2 Logistic Regression

    1. import torch
    2. import torchvision
    3. import torch.nn.functional as F
    4. import numpy as np
    5. import matplotlib.pyplot as plt
    6. # Prepare dataset
    7. x_data = torch.Tensor([[1.0], [2.0], [3.0]])
    8. y_data = torch.Tensor([[0], [0], [1]])
    9. # Define the model
    10. class LogisticRegressionModel(torch.nn.Module):
    11. def __init__(self):
    12. super(LogisticRegressionModel, self).__init__()
    13. self.linear = torch.nn.Linear(1, 1)
    14. def forward(self, x):
    15. y_pred = F.sigmoid(self.linear(x))
    16. return y_pred
    17. model = LogisticRegressionModel()
    18. # Define the criterion and optimizer
    19. criterion = torch.nn.BCELoss(size_average=False)
    20. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    21. # Training
    22. for epoch in range(1000):
    23. y_pred = model(x_data)
    24. loss = criterion(y_pred, y_data)
    25. print('Epoch[{}/{}],loss:{:.6f}'.format(epoch, 1000, loss.item()))
    26. optimizer.zero_grad()
    27. loss.backward()
    28. optimizer.step()
    29. # Drawing
    30. x = np.linspace(0, 10, 200)
    31. x_t = torch.Tensor(x).view((200, 1))
    32. y_t = model(x_t)
    33. y = y_t.data.numpy()
    34. plt.plot(x, y)
    35. plt.plot([0, 10], [0.5, 0.5], c='r')
    36. plt.xlabel('Hours')
    37. plt.ylabel('Probability of Pass')
    38. plt.show()

    Result 

    可以看到当x等于2.5的时候,预测值刚好为0.5。这是与我们的数据集是相匹配的

    参考资料

    《机器学习》周志华

    《深度学习与机器学习》吴恩达

    《神经网络与与深度学习》邱锡鹏

    《Pytorch深度学习实战》刘二大人

  • 相关阅读:
    【拯救大学生计划】:我做了一个QQ分组神器
    可循环视频播放器丨VideoPlayer丨StreamingAssets加载
    Qt QCustomPlot 点状网格线实现和曲线坐标点拾取
    远程代码执行渗透测试——B模块测试
    在Ubuntu上安装使用PostgreSQL数据库
    Ubuntu安装clickhouse集群
    企业计算机服务器中了mallox勒索病毒怎么解决,勒索病毒解密文件恢复
    java插值查找(含插值查找的代码)
    Java配置25-搭建Jenkins服务器
    数据挖掘比赛比较基础的baseline
  • 原文地址:https://blog.csdn.net/ccaoshangfei/article/details/126754731