• 刘二大人 PyTorch深度学习实践 笔记 P6 逻辑斯蒂回归


    P6 逻辑斯蒂回归

    1、torchversion 提供的数据集

    import torchvision # PyTorch提供的工具包
    
    # 手写数字识别集 0-9
    # root 存储路径 train 训练集还是测试集 download 从网上下载
    train_set = torchvision.datasets.MNIST(root='./dataset/mnist', train=True, download=True)
    test_set = torchvision.datasets.MNIST(root='./dataset/mnist', train=False, download=True)
    
    # 彩色小图片数据集 分成10个分类 猫、狗等...
    train_set = torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=True, download=True)
    test_set = torchvision.datasets.CIFAR10(root='./dataset/cifar10', train=False, download=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    2、基本概念

    回归任务: y ∈ R 连续的空间
    逻辑斯蒂回归: 名字叫回归,但是是做分类的,估算y属于哪一个类别,不是让y等于某一个特定值,因为类别无法进行数值之间的大小比较,而是计算属于每一个分类的概率,概率最大的就是分类的结果。

    在这里插入图片描述

    二分类: 只有两个类别的分类问题 且 P(y = 1) + P(y = 0) = 1
    在这里插入图片描述
    现在想计算概率属于[0, 1],而不是实数,使用sigmod()函数将实数空间映射到[0, 1]之间
    在这里插入图片描述

    sigmod函数特征

    1. [-1, 1]
    2. 单调增函数
    3. 饱和函数
      在这里插入图片描述
      所以逻辑斯蒂回归就是在线性回归的基础上增加一个sigmod函数,保证输出值在0 ~ 1之间
      在这里插入图片描述
      比较两个分布之间的差异,二分类可以使用交叉熵损失BCE函数,预测值与标签越接近,损失值越小
      在这里插入图片描述

    3、代码实现

    import torch.nn
    # import torch.nn.functional as F
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 建立模型
    class LogisticRegressionModel(torch.nn.Module):
    	def __init__(self):
    		super(LogisticRegressionModel, self).__init__()
    		self.linear = torch.nn.Linear(1, 1)
    
    	def forward(self, x):
    		# y_pred = F.sigmoid(self.linear(x))
    		y_pred = torch.sigmoid(self.linear(x))
    		return y_pred
    
    # 准备数据集
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    # 需要是floatTensor类型的数据,以下两种方式均可
    # P4代码中有注释解释
    y_data = torch.tensor([[0.], [0.], [1.]])
    # y_data = torch.Tensor([[0], [0], [1]])
    
    model = LogisticRegressionModel()
    
    # criterion = torch.nn.BCELoss(size_average=False) 已被弃用
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    # 训练
    for epoch in range(1000):
    	y_pred = model(x_data)
    	loss = criterion(y_pred, y_data)
    	print(epoch, loss.item())
    
    	optimizer.zero_grad()
    	loss.backward()
    	optimizer.step()
    
    print('w = ', model.linear.weight.item())
    print('b = ', model.linear.bias.item())
    
    # 测试
    x_test = torch.Tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred = ', y_test.item())
    
    # 画图
    x = np.linspace(0, 10, 200) # 返回0-10等间距的200个数
    x_t = torch.Tensor(x).view((200, 1)) # 生成一个200行1列的矩阵tensor
    y_t = model(x_t)
    y = y_t.data.numpy() # 调用numpy将y_t变成n维数组
    
    plt.plot(x, y)
    plt.plot([0, 10], [0.5, 0.5], c='r') # 画线,x取值0-10,y=0.5
    plt.xlabel('Hours')
    plt.ylabel('Probability of Pass')
    plt.grid() # 显示网格线 1=True=默认显示;0=False=不显示
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59

    输出:

    ...
    986 1.0572500228881836
    987 1.0567622184753418
    988 1.056274652481079
    989 1.0557879209518433
    990 1.055301547050476
    991 1.054815649986267
    992 1.0543304681777954
    993 1.0538458824157715
    994 1.0533615350723267
    995 1.0528780221939087
    996 1.0523948669433594
    997 1.0519124269485474
    998 1.0514304637908936
    999 1.050948977470398
    w =  1.1907615661621094
    b =  -2.876981258392334
    y_pred =  0.8683062195777893
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    在这里插入图片描述

  • 相关阅读:
    使用docker部署lnmp多站点
    ILRuntime热更的小技巧
    ...spread、命名空间、假报错
    H12-821_29
    手把手带你用Python和文心一言搭建《AI看图写诗》网页项目(附上完整项目源码)
    2.继承总结方法
    08-循环神经网络实现文本情感分类
    Spring Boot 7 微服务执行Bot代码(传递路线是难点)
    【推荐系统】推荐系统-基础算法 冷启动、及深度学习在冷启动上的应用
    悬浮窗开发设计实践
  • 原文地址:https://blog.csdn.net/qq_44948213/article/details/126404028