• 刘二大人 PyTorch Deep Learning Practice Notes, P3: Gradient Descent


    P3 Gradient Descent

    1. Algorithm idea and its problems

    Exhaustive search: evaluate every candidate weight and pick the one with the lowest cost.
    Problems:

    • In practice the cost curve is rarely this smooth and well-behaved, so the optimum cannot simply be read off a plot
    • With multi-dimensional weights, exhaustive search becomes intractable because the number of candidates explodes

    Divide and conquer: split the search range into four parts, sample 16 points, keep the part with the smaller values, then split and sample again; this homes in on a local optimum.
    Problems:

    • It works if the function is convex
    • On an irregular (non-convex) function it can easily skip over the true optimum

    Optimization problem: find the weight that minimizes the cost function.

    Gradient descent: at each step, move a small amount in the direction of the negative gradient.
    This finds a local optimum, much like a greedy algorithm, but there is no guarantee of reaching the global optimum.
    Gradient descent is nevertheless widely used, because in practice models rarely get trapped in bad local minima; there are not that many of them. Saddle points do occur, however: the gradient there is 0 (for example, f(w) = w³ at w = 0), so the update stalls and the algorithm cannot keep iterating.
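
    Since the original slide images are not included here, the standard formulas for the linear model used in this lecture are written out below (y_hat = x · ω, learning rate α, called r in the code); the code in the next section implements exactly these:

    \hat{y}_n = x_n \cdot \omega
    cost(\omega) = \frac{1}{N} \sum_{n=1}^{N} (x_n \cdot \omega - y_n)^2
    \frac{\partial cost}{\partial \omega} = \frac{1}{N} \sum_{n=1}^{N} 2\, x_n (x_n \cdot \omega - y_n)
    \omega \leftarrow \omega - \alpha \frac{\partial cost}{\partial \omega}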

    2. Gradient descent: code implementation

    import matplotlib.pyplot as plt
    
    # Model: y_hat = x * w
    def forward(x):
    	return x * w
    
    # Cost: mean squared error (MSE) over the whole dataset
    def cost(xs, ys):
    	cost = 0
    	for x, y in zip(xs, ys):
    		y_pred = forward(x)
    		cost += (y_pred - y) ** 2
    	return cost / len(xs)
    
    # Gradient of the MSE cost with respect to w
    def gradient(xs, ys):
    	grad = 0
    	for x, y in zip(xs, ys):
    		grad += 2 * x * (x*w - y)
    	return grad / len(xs)
    
    # Dataset (y = 2x)
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    
    # Initial weight
    w = 1.0
    # Learning rate
    r = 0.01
    print('Predict (before training)', 4, forward(4))
    epoch_list = [] # epoch indices
    cost_list = [] # MSE per epoch
    for epoch in range(100):
    	cost_val = cost(x_data, y_data)
    	grad_val = gradient(x_data, y_data)
    	w = w - r * grad_val  # step in the negative gradient direction
    	print('Epoch:', epoch, 'w=', w, 'cost=', round(cost_val,3))
    	epoch_list.append(epoch)
    	cost_list.append(cost_val)
    print('Predict (after training)', 4, forward(4))
    
    # Plot cost vs. epoch
    plt.plot(epoch_list, cost_list)
    plt.xlabel('Epoch')
    plt.ylabel('Cost')
    plt.show()
    

    Output:

    Predict (before training) 4 4.0
    Epoch: 0 w= 1.0933333333333333 cost= 4.667
    Epoch: 1 w= 1.1779555555555554 cost= 3.836
    Epoch: 2 w= 1.2546797037037036 cost= 3.154
    Epoch: 3 w= 1.3242429313580246 cost= 2.592
    Epoch: 4 w= 1.3873135910979424 cost= 2.131
    Epoch: 5 w= 1.4444976559288012 cost= 1.752
    Epoch: 6 w= 1.4963445413754464 cost= 1.44
    Epoch: 7 w= 1.5433523841804047 cost= 1.184
    Epoch: 8 w= 1.5859728283235668 cost= 0.973
    Epoch: 9 w= 1.6246153643467005 cost= 0.8
    Epoch: 10 w= 1.659651263674342 cost= 0.658
    Epoch: 11 w= 1.6914171457314033 cost= 0.541
    Epoch: 12 w= 1.7202182121298057 cost= 0.444
    Epoch: 13 w= 1.7463311789976905 cost= 0.365
    Epoch: 14 w= 1.7700069356245727 cost= 0.3
    Epoch: 15 w= 1.7914729549662791 cost= 0.247
    Epoch: 16 w= 1.8109354791694263 cost= 0.203
    Epoch: 17 w= 1.8285815011136133 cost= 0.167
    Epoch: 18 w= 1.8445805610096762 cost= 0.137
    Epoch: 19 w= 1.8590863753154396 cost= 0.113
    Epoch: 20 w= 1.872238313619332 cost= 0.093
    Epoch: 21 w= 1.8841627376815275 cost= 0.076
    Epoch: 22 w= 1.8949742154979183 cost= 0.063
    Epoch: 23 w= 1.904776622051446 cost= 0.051
    Epoch: 24 w= 1.9136641373266443 cost= 0.042
    Epoch: 25 w= 1.9217221511761575 cost= 0.035
    Epoch: 26 w= 1.9290280837330496 cost= 0.029
    Epoch: 27 w= 1.9356521292512983 cost= 0.024
    Epoch: 28 w= 1.9416579305211772 cost= 0.019
    Epoch: 29 w= 1.9471031903392007 cost= 0.016
    Epoch: 30 w= 1.952040225907542 cost= 0.013
    Epoch: 31 w= 1.9565164714895047 cost= 0.011
    Epoch: 32 w= 1.9605749341504843 cost= 0.009
    Epoch: 33 w= 1.9642546069631057 cost= 0.007
    Epoch: 34 w= 1.9675908436465492 cost= 0.006
    Epoch: 35 w= 1.970615698239538 cost= 0.005
    Epoch: 36 w= 1.9733582330705144 cost= 0.004
    Epoch: 37 w= 1.975844797983933 cost= 0.003
    Epoch: 38 w= 1.9780992835054327 cost= 0.003
    Epoch: 39 w= 1.980143350378259 cost= 0.002
    Epoch: 40 w= 1.9819966376762883 cost= 0.002
    Epoch: 41 w= 1.983676951493168 cost= 0.002
    Epoch: 42 w= 1.9852004360204722 cost= 0.001
    Epoch: 43 w= 1.9865817286585614 cost= 0.001
    Epoch: 44 w= 1.987834100650429 cost= 0.001
    Epoch: 45 w= 1.9889695845897222 cost= 0.001
    Epoch: 46 w= 1.9899990900280147 cost= 0.001
    Epoch: 47 w= 1.9909325082920666 cost= 0.0
    Epoch: 48 w= 1.9917788075181404 cost= 0.0
    Epoch: 49 w= 1.9925461188164473 cost= 0.0
    Epoch: 50 w= 1.9932418143935788 cost= 0.0
    Epoch: 51 w= 1.9938725783835114 cost= 0.0
    Epoch: 52 w= 1.994444471067717 cost= 0.0
    Epoch: 53 w= 1.9949629871013967 cost= 0.0
    Epoch: 54 w= 1.9954331083052663 cost= 0.0
    Epoch: 55 w= 1.9958593515301082 cost= 0.0
    Epoch: 56 w= 1.9962458120539648 cost= 0.0
    Epoch: 57 w= 1.9965962029289281 cost= 0.0
    Epoch: 58 w= 1.9969138906555615 cost= 0.0
    Epoch: 59 w= 1.997201927527709 cost= 0.0
    Epoch: 60 w= 1.9974630809584561 cost= 0.0
    Epoch: 61 w= 1.9976998600690001 cost= 0.0
    Epoch: 62 w= 1.9979145397958935 cost= 0.0
    Epoch: 63 w= 1.9981091827482769 cost= 0.0
    Epoch: 64 w= 1.9982856590251044 cost= 0.0
    Epoch: 65 w= 1.9984456641827613 cost= 0.0
    Epoch: 66 w= 1.9985907355257035 cost= 0.0
    Epoch: 67 w= 1.9987222668766378 cost= 0.0
    Epoch: 68 w= 1.9988415219681517 cost= 0.0
    Epoch: 69 w= 1.9989496465844576 cost= 0.0
    Epoch: 70 w= 1.9990476795699081 cost= 0.0
    Epoch: 71 w= 1.9991365628100501 cost= 0.0
    Epoch: 72 w= 1.999217150281112 cost= 0.0
    Epoch: 73 w= 1.999290216254875 cost= 0.0
    Epoch: 74 w= 1.9993564627377531 cost= 0.0
    Epoch: 75 w= 1.9994165262155628 cost= 0.0
    Epoch: 76 w= 1.999470983768777 cost= 0.0
    Epoch: 77 w= 1.9995203586170245 cost= 0.0
    Epoch: 78 w= 1.9995651251461022 cost= 0.0
    Epoch: 79 w= 1.9996057134657994 cost= 0.0
    Epoch: 80 w= 1.9996425135423248 cost= 0.0
    Epoch: 81 w= 1.999675878945041 cost= 0.0
    Epoch: 82 w= 1.999706130243504 cost= 0.0
    Epoch: 83 w= 1.9997335580874436 cost= 0.0
    Epoch: 84 w= 1.9997584259992822 cost= 0.0
    Epoch: 85 w= 1.9997809729060159 cost= 0.0
    Epoch: 86 w= 1.9998014154347876 cost= 0.0
    Epoch: 87 w= 1.9998199499942075 cost= 0.0
    Epoch: 88 w= 1.9998367546614149 cost= 0.0
    Epoch: 89 w= 1.9998519908930161 cost= 0.0
    Epoch: 90 w= 1.9998658050763347 cost= 0.0
    Epoch: 91 w= 1.9998783299358769 cost= 0.0
    Epoch: 92 w= 1.9998896858085284 cost= 0.0
    Epoch: 93 w= 1.9998999817997325 cost= 0.0
    Epoch: 94 w= 1.9999093168317574 cost= 0.0
    Epoch: 95 w= 1.9999177805941268 cost= 0.0
    Epoch: 96 w= 1.9999254544053418 cost= 0.0
    Epoch: 97 w= 1.9999324119941766 cost= 0.0
    Epoch: 98 w= 1.9999387202080534 cost= 0.0
    Epoch: 99 w= 1.9999444396553017 cost= 0.0
    Predict (after training) 4 7.999777758621207
    

    If the plotted cost curve is not smooth, an exponentially weighted moving average of the cost can be plotted instead, as in the sketch below.

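    A minimal smoothing sketch, assuming the epoch_list and cost_list produced by the loop above; ewma and beta are illustrative names, not part of the lecture's code (beta closer to 1 means heavier smoothing):

    def ewma(values, beta=0.9):
    	# Exponentially weighted moving average of a list of values
    	smoothed = []
    	avg = values[0]
    	for v in values:
    		avg = beta * avg + (1 - beta) * v
    		smoothed.append(avg)
    	return smoothed
    
    plt.plot(epoch_list, ewma(cost_list))
    plt.xlabel('Epoch')
    plt.ylabel('Smoothed cost')
    plt.show()
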
    If training fails (the cost grows instead of shrinking), the learning rate is probably set too large, as the short sketch below illustrates.
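
    A quick way to see this on the same toy model (a sketch that reuses cost and gradient from the code above; r = 0.5 is a deliberately oversized, hypothetical learning rate):

    w = 1.0
    r = 0.5  # deliberately too large
    for epoch in range(5):
    	w = w - r * gradient(x_data, y_data)
    	print('Epoch:', epoch, 'w=', w, 'cost=', cost(x_data, y_data))
    # The printed cost increases every epoch: each update overshoots the minimum.
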
    In practice, stochastic gradient descent (SGD) is used more often: randomly pick a single sample, take the derivative of that sample's loss with respect to the weight, and update.

    3. Stochastic gradient descent: code implementation

    import matplotlib.pyplot as plt
    
    # Model: y_hat = x * w
    def forward(x):
    	return x * w
    
    # Loss of a single sample
    def loss(x, y):
    	return (forward(x) - y) ** 2
    
    # Gradient of a single sample's loss with respect to w
    def gradient(x, y):
    	return 2 * x * (forward(x) - y)
    
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    
    # Initial weight and learning rate
    w = 1.0
    r = 0.01
    print('Predict (before training): ', 4, forward(4))
    
    epoch_list = []
    loss_list = []
    
    for epoch in range(100):
    	# Update w once per sample (here the samples are visited in order, not shuffled)
    	for x, y in zip(x_data, y_data):
    		grad = gradient(x, y)
    		w -= r * grad
    		l = loss(x, y)
    		# print('\tgrad:', x, y, grad)
    
    	print('Progress:', epoch, 'w=', w, 'loss=', round(l, 6))
    	epoch_list.append(epoch)
    	loss_list.append(l)
    
    print('Predict (after training): ', 4, forward(4))
    
    plt.plot(epoch_list, loss_list)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()
    

    Output:

    Predict (before training):  4 4.0
    Progress: 0 w= 1.260688 loss= 4.91924
    Progress: 1 w= 1.453417766656 loss= 2.688769
    Progress: 2 w= 1.5959051959019805 loss= 1.469633
    Progress: 3 w= 1.701247862192685 loss= 0.803276
    Progress: 4 w= 1.7791289594933983 loss= 0.439056
    Progress: 5 w= 1.836707389300983 loss= 0.23998
    Progress: 6 w= 1.8792758133988885 loss= 0.131169
    Progress: 7 w= 1.910747160155559 loss= 0.071695
    Progress: 8 w= 1.9340143044689266 loss= 0.039187
    Progress: 9 w= 1.9512159834655312 loss= 0.021419
    Progress: 10 w= 1.9639333911678687 loss= 0.011707
    Progress: 11 w= 1.9733355232910992 loss= 0.006399
    Progress: 12 w= 1.9802866323953892 loss= 0.003498
    Progress: 13 w= 1.9854256707695 loss= 0.001912
    Progress: 14 w= 1.9892250235079405 loss= 0.001045
    Progress: 15 w= 1.9920339305797026 loss= 0.000571
    Progress: 16 w= 1.994110589284741 loss= 0.000312
    Progress: 17 w= 1.9956458879852805 loss= 0.000171
    Progress: 18 w= 1.9967809527381737 loss= 9.3e-05
    Progress: 19 w= 1.9976201197307648 loss= 5.1e-05
    Progress: 20 w= 1.998240525958391 loss= 2.8e-05
    Progress: 21 w= 1.99869919972735 loss= 1.5e-05
    Progress: 22 w= 1.9990383027488265 loss= 8e-06
    Progress: 23 w= 1.9992890056818404 loss= 5e-06
    Progress: 24 w= 1.999474353368653 loss= 2e-06
    Progress: 25 w= 1.9996113831376856 loss= 1e-06
    Progress: 26 w= 1.9997126908902887 loss= 1e-06
    Progress: 27 w= 1.9997875889274812 loss= 0.0
    Progress: 28 w= 1.9998429619451539 loss= 0.0
    Progress: 29 w= 1.9998838998815958 loss= 0.0
    Progress: 30 w= 1.9999141657892625 loss= 0.0
    Progress: 31 w= 1.9999365417379913 loss= 0.0
    Progress: 32 w= 1.9999530845453979 loss= 0.0
    Progress: 33 w= 1.9999653148414271 loss= 0.0
    Progress: 34 w= 1.999974356846045 loss= 0.0
    Progress: 35 w= 1.9999810417085633 loss= 0.0
    Progress: 36 w= 1.9999859839076413 loss= 0.0
    Progress: 37 w= 1.9999896377347262 loss= 0.0
    Progress: 38 w= 1.999992339052936 loss= 0.0
    Progress: 39 w= 1.9999943361699042 loss= 0.0
    Progress: 40 w= 1.9999958126624442 loss= 0.0
    Progress: 41 w= 1.999996904251097 loss= 0.0
    Progress: 42 w= 1.999997711275687 loss= 0.0
    Progress: 43 w= 1.9999983079186507 loss= 0.0
    Progress: 44 w= 1.9999987490239537 loss= 0.0
    Progress: 45 w= 1.9999990751383971 loss= 0.0
    Progress: 46 w= 1.9999993162387186 loss= 0.0
    Progress: 47 w= 1.9999994944870796 loss= 0.0
    Progress: 48 w= 1.9999996262682318 loss= 0.0
    Progress: 49 w= 1.999999723695619 loss= 0.0
    Progress: 50 w= 1.9999997957248556 loss= 0.0
    Progress: 51 w= 1.9999998489769344 loss= 0.0
    Progress: 52 w= 1.9999998883468353 loss= 0.0
    Progress: 53 w= 1.9999999174534755 loss= 0.0
    Progress: 54 w= 1.999999938972364 loss= 0.0
    Progress: 55 w= 1.9999999548815364 loss= 0.0
    Progress: 56 w= 1.9999999666433785 loss= 0.0
    Progress: 57 w= 1.9999999753390494 loss= 0.0
    Progress: 58 w= 1.9999999817678633 loss= 0.0
    Progress: 59 w= 1.9999999865207625 loss= 0.0
    Progress: 60 w= 1.999999990034638 loss= 0.0
    Progress: 61 w= 1.9999999926324883 loss= 0.0
    Progress: 62 w= 1.99999999455311 loss= 0.0
    Progress: 63 w= 1.9999999959730488 loss= 0.0
    Progress: 64 w= 1.9999999970228268 loss= 0.0
    Progress: 65 w= 1.9999999977989402 loss= 0.0
    Progress: 66 w= 1.9999999983727301 loss= 0.0
    Progress: 67 w= 1.9999999987969397 loss= 0.0
    Progress: 68 w= 1.999999999110563 loss= 0.0
    Progress: 69 w= 1.9999999993424284 loss= 0.0
    Progress: 70 w= 1.9999999995138495 loss= 0.0
    Progress: 71 w= 1.9999999996405833 loss= 0.0
    Progress: 72 w= 1.999999999734279 loss= 0.0
    Progress: 73 w= 1.9999999998035491 loss= 0.0
    Progress: 74 w= 1.9999999998547615 loss= 0.0
    Progress: 75 w= 1.9999999998926234 loss= 0.0
    Progress: 76 w= 1.9999999999206153 loss= 0.0
    Progress: 77 w= 1.9999999999413098 loss= 0.0
    Progress: 78 w= 1.9999999999566096 loss= 0.0
    Progress: 79 w= 1.9999999999679208 loss= 0.0
    Progress: 80 w= 1.9999999999762834 loss= 0.0
    Progress: 81 w= 1.999999999982466 loss= 0.0
    Progress: 82 w= 1.9999999999870368 loss= 0.0
    Progress: 83 w= 1.999999999990416 loss= 0.0
    Progress: 84 w= 1.9999999999929146 loss= 0.0
    Progress: 85 w= 1.9999999999947617 loss= 0.0
    Progress: 86 w= 1.9999999999961273 loss= 0.0
    Progress: 87 w= 1.999999999997137 loss= 0.0
    Progress: 88 w= 1.9999999999978835 loss= 0.0
    Progress: 89 w= 1.9999999999984353 loss= 0.0
    Progress: 90 w= 1.9999999999988431 loss= 0.0
    Progress: 91 w= 1.9999999999991447 loss= 0.0
    Progress: 92 w= 1.9999999999993676 loss= 0.0
    Progress: 93 w= 1.9999999999995324 loss= 0.0
    Progress: 94 w= 1.9999999999996543 loss= 0.0
    Progress: 95 w= 1.9999999999997444 loss= 0.0
    Progress: 96 w= 1.999999999999811 loss= 0.0
    Progress: 97 w= 1.9999999999998603 loss= 0.0
    Progress: 98 w= 1.9999999999998967 loss= 0.0
    Progress: 99 w= 1.9999999999999236 loss= 0.0
    Predict (after training):  4 7.9999999999996945
    


    Remaining problems:
    Gradient descent: the gradient is computed over the whole dataset, so the computation can be parallelized and it is the most efficient, but the resulting learner tends to perform worse.
    Stochastic gradient descent: the per-sample updates cannot be parallelized (each update depends on the previous one), the learner tends to perform better, but the time cost is higher.

    Solution:
    A compromise: Batch / Mini-batch stochastic gradient descent, i.e. update on small batches of samples; the mini-batch idea is sketched below.
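
    A minimal plain-Python sketch of the mini-batch idea on the same toy data (illustrative only, not the lecture's code; batch_size and batch_gradient are hypothetical names):

    import random
    
    x_data = [1.0, 2.0, 3.0]
    y_data = [2.0, 4.0, 6.0]
    w = 1.0
    r = 0.01
    batch_size = 2  # hypothetical choice
    
    # Average gradient over one mini-batch
    def batch_gradient(xs, ys):
    	grad = 0
    	for x, y in zip(xs, ys):
    		grad += 2 * x * (x * w - y)
    	return grad / len(xs)
    
    for epoch in range(100):
    	# Shuffle once per epoch, then walk through the data in small batches
    	samples = list(zip(x_data, y_data))
    	random.shuffle(samples)
    	for i in range(0, len(samples), batch_size):
    		batch = samples[i:i + batch_size]
    		xs = [x for x, _ in batch]
    		ys = [y for _, y in batch]
    		w = w - r * batch_gradient(xs, ys)
    
    print('w after mini-batch training:', w)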

  • Original article: https://blog.csdn.net/qq_44948213/article/details/126372585