• 刘二大人 PyTorch深度学习实践 笔记 P5 用PyTorch实现线性回归


    P5 用PyTorch实现线性回归

    1、线性回归基本概念

    概念: 只具有一个神经元的最简单的神经网络
    步骤:

    1. 构建数据集
    2. 设计模型 前馈 用来计算预测y
    3. 构造损失函数和优化器 使用PyTorch的API
    4. 训练周期 前馈(算损失) ⇒ 反馈(算梯度) ⇒ 更新(更新权重)

    2、维度

    输入数据都是矩阵,Linear确定权重和偏置的维度,需要知道 x和z的维度

    loss必须要是标量,对所有loss求均值,才可以用backward()
    在这里插入图片描述

    做转置,拼维度,用y = wTx + b
    在这里插入图片描述

    3、代码实现

    I 魔法函数

    对模型进行实例化,会先初始化,然后调用魔法函数,魔法函数再调用前馈函数

    class Foobar:
    	def __init__(self):
    		super().__init__()
    		print('__init__函数被调用了')
    		pass
    
    	def __call__(self, *args, **kwargs):
    		# *args 可变长参数(无名参数) 把前面n个参数变成n元组
    		# **kwargs 把参数变成词典词典 x=2, y=3 ==> {x:2, y:3}
    		print('__call__函数被调用了')
    		print('Hello:' + str(args[0]))
    		print('Hello:' + str(kwargs))
    		self.forward()
    
    	def forward(self):
    		print('forward函数被调用了')
    
    foobar = Foobar()
    foobar(1, 2, 3, x = 1, y = 2)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19

    输出:

    __init__函数被调用了
    __call__函数被调用了
    Hello:1
    Hello:{'x': 1, 'y': 2}
    forward函数被调用了
    
    • 1
    • 2
    • 3
    • 4
    • 5

    II 线性回归代码实现

    import torch
    import matplotlib.pyplot as plt
    
    # 数据集 3行1列的矩阵
    x_data = torch.tensor([[1.0], [2.0], [3.0]])
    y_data = torch.tensor([[2.0], [4.0], [6.0]])
    
    # 模型定义成一个类,将来可以扩展模型,以适应于各种各样的任务
    # 所有类都要继承Module,里面有好多方法可以直接使用
    # nn 是 Netural Network的缩写
    class LinearModel(torch.nn.Module):
    	# 类中至少要实现两个函数
    	# 1. 构造函数:用来初始化对象
    	def __init__(self):
    		# 调用父类的构造函数 LinearModel:模型名称
    		super(LinearModel, self).__init__()
    
    		# torch.nn.Linear是PyTorch里面的一个类
    		# torch.nn.Linear(1, 1) 在构造一个对象,包含了权重和偏置两个tensor
    		# 可以自动来完成 x*w + b的计算
    		# Linear(1, 1)也是继承自Module,也可以自动进行backward计算
    		self.linear = torch.nn.Linear(1, 1)
    		# 参数:in_features 输入纬度 out_features 输出纬度 bias 默认为True
    
    	# 2. 前馈函数:进行前馈需要执行的计算
    	# pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数
    	# 因此新写的类中需要重写forward()以覆盖掉父类中的forward()
    	# 该函数的另一个作用是可以直接在对象后面加()
    	# 例如实例化的model对象,和实例化的linear对象
    	# pytorch也是按照__init__, __call__, forward三个函数实现网络层之间的架构的
    	def forward(self, x):
    		y_pred = self.linear(x) # 对象后面加(),说明实现了一个可调用的对象
    		return y_pred
    
    # 用Module构造出来的对象,会根据计算图自动进行backward()的计算
    # 如果自己定义的模型,没办法进行求倒数
    # 方法一:模块由PyTorch基本计算封装成类,实例化Module,调用即可
    # 方法二:自己有更快的计算方式,可以在Functions类里构造计算块,进行继承
    # 当然,直接用Module里面的模块算是最简单的
    
    # 模型实例化
    model = LinearModel()
    
    criterion = torch.nn.MSELoss(reduction='sum')
    # size_average 损失是否求均值 没啥用 改成 reduction='sum'
    # 做mini batch 可以设置为True 不过影响也不是很大
    # reduce 用来确定是否降维 不考虑
    
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # 优化器与Module无关,不会构建计算图
    # 传入数据为单个,为随机梯度下降
    # 传入数据为batch,为批量梯度下降,此处为SGD
    # 类SGD model.parameters() 把linear里的w都拿出来
    # lr 学习率
    
    # 1.先算y_pred
    # 2. loss 梯度清0
    # 3. backward
    # 4. update
    
    epoch_list = []
    loss_list = []
    
    for epoch in range(100):
    	y_pred = model(x_data) # forward 计算y_pred
    	loss = criterion(y_pred, y_data) # 计算损失
    	# print(type(loss))
    	print(epoch, loss.item())
    	
    	epoch_list.append(epoch)
    	loss_list.append(loss.item())
    
    	optimizer.zero_grad() # 梯度归0
    	loss.backward() # backward 计算梯度
    	optimizer.step() # update w, b
    
    # 输出权重和偏置
    print('w=', model.linear.weight.item())
    # weight是一个矩阵,item()才是值
    print('b=', model.linear.bias.item())
    
    # Test Model
    x_test = torch.tensor([[4.0]])
    y_test = model(x_test)
    print('y_pred=', y_test.item())
    
    # 画图
    plt.plot(epoch_list, loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91

    输出:

    0 84.3682632446289
    1 37.785797119140625
    2 17.045326232910156
    3 7.809038162231445
    4 3.6941323280334473
    5 1.859161138534546
    6 1.0392004251480103
    7 0.6711367964744568
    8 0.5042883157730103
    9 0.4270588159561157
    10 0.3897675573825836
    11 0.37029650807380676
    12 0.35880059003829956
    13 0.3508957624435425
    14 0.34462952613830566
    15 0.3391319811344147
    16 0.3340155780315399
    17 0.329107403755188
    18 0.32432955503463745
    19 0.31964707374572754
    20 0.3150433897972107
    21 0.3105117380619049
    22 0.3060474991798401
    23 0.30164802074432373
    24 0.2973126769065857
    25 0.2930392026901245
    26 0.28882789611816406
    27 0.2846769690513611
    28 0.28058576583862305
    29 0.276553213596344
    30 0.27257853746414185
    31 0.2686613202095032
    32 0.2648002505302429
    33 0.2609947621822357
    34 0.25724393129348755
    35 0.2535468339920044
    36 0.24990268051624298
    37 0.2463115155696869
    38 0.24277149140834808
    39 0.23928232491016388
    40 0.23584377765655518
    41 0.23245413601398468
    42 0.22911357879638672
    43 0.22582058608531952
    44 0.22257545590400696
    45 0.21937674283981323
    46 0.21622377634048462
    47 0.21311627328395844
    48 0.21005336940288544
    49 0.20703470706939697
    50 0.2040593922138214
    51 0.20112663507461548
    52 0.19823598861694336
    53 0.19538730382919312
    54 0.19257915019989014
    55 0.18981142342090607
    56 0.18708369135856628
    57 0.18439501523971558
    58 0.1817447394132614
    59 0.17913301289081573
    60 0.17655833065509796
    61 0.17402119934558868
    62 0.17152008414268494
    63 0.16905491054058075
    64 0.16662533581256866
    65 0.16423071920871735
    66 0.16187059879302979
    67 0.15954403579235077
    68 0.1572512835264206
    69 0.1549912840127945
    70 0.15276378393173218
    71 0.15056845545768738
    72 0.1484045833349228
    73 0.14627176523208618
    74 0.14416931569576263
    75 0.14209739863872528
    76 0.14005550742149353
    77 0.1380425989627838
    78 0.1360587775707245
    79 0.1341031938791275
    80 0.1321759670972824
    81 0.13027633726596832
    82 0.1284041702747345
    83 0.12655889987945557
    84 0.1247401311993599
    85 0.12294728308916092
    86 0.12118029594421387
    87 0.11943873763084412
    88 0.117722287774086
    89 0.11603038758039474
    90 0.11436278373003006
    91 0.11271937191486359
    92 0.11109931766986847
    93 0.10950258374214172
    94 0.10792884975671768
    95 0.10637786984443665
    96 0.10484891384840012
    97 0.10334222763776779
    98 0.10185694694519043
    99 0.10039319097995758
    w= 1.7890673875808716
    b= 0.4794996380805969
    y_pred= 7.635769367218018
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103

    在这里插入图片描述
    训练1000次,效果很好

    增加训练次数,尽可能接近训练集,但是对于测试集可能会过拟合,训练集和开发集上都要观察

    4、作业:不同优化器的损失图差别

    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    LBFGS要传递闭包,暂未解决

    TypeError: step() missing 1 required positional argument: 'closure'
    
    • 1

    PyTorch更多使用可以参考官网 https://pytorch.org/tutorials/beginner/pytorch_with_examples.html

  • 相关阅读:
    gcc: error: : No such file or directory
    ss-4.2 多个eureka集群案例
    Mybatis 10
    Elasticsearch下载
    辣子鸡丁的家常做法 辣子鸡丁怎么做
    工业节碳分论坛精彩回顾 | 第二届始祖数字化可持续发展峰会
    docker和docker compose安装使用、入门进阶案例
    波及Win 11,让安全员自动放弃的零日漏洞,微软这次麻烦了
    C规范编辑笔记(六)
    【计算机视觉40例】案例38:驾驶员疲劳监测
  • 原文地址:https://blog.csdn.net/qq_44948213/article/details/126392027