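Exhaustive search: evaluate every candidate point and pick the best one.
Problems:
Divide and conquer: split the search range into four parts and evaluate 16 points, pick the region with the smaller values, split it into four parts again and evaluate points there, and so on. This finds a locally optimal point.
Problems: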
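The two search strategies above can be sketched on a one-weight model. This is a minimal illustration, not the notes' own code: the function names `mse`, `exhaustive_search`, and `coarse_to_fine`, the search range, and the point counts are all assumptions, and the search is reduced to a single weight `w` rather than the 16-point grid mentioned above.

```python
# Toy data consistent with y = 2 * x (assumed for illustration).
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def mse(w):
    """Mean squared error of the model y_pred = x * w."""
    return sum((x * w - y) ** 2 for x, y in zip(x_data, y_data)) / len(x_data)

def exhaustive_search(lo, hi, steps):
    """Evaluate steps + 1 evenly spaced candidates in [lo, hi]; return the best."""
    candidates = [lo + (hi - lo) * i / steps for i in range(steps + 1)]
    return min(candidates, key=mse)

def coarse_to_fine(lo, hi, points=4, rounds=10):
    """Sample a few points, then zoom in around the best one and repeat."""
    best = lo
    for _ in range(rounds):
        step = (hi - lo) / points
        best = exhaustive_search(lo, hi, points)
        lo, hi = best - step, best + step  # shrink the window around the winner
    return best

w_grid = exhaustive_search(0.0, 4.0, 40)  # 41 evaluations in one pass
w_fine = coarse_to_fine(0.0, 4.0)         # 5 evaluations per round
print('exhaustive:', w_grid, mse(w_grid))
print('coarse-to-fine:', w_fine, mse(w_fine))
```

Both strategies work here because the cost has a single minimum. With several weights, the number of exhaustive evaluations grows as a power of the number of weights, and on a cost surface with several valleys the coarse-to-fine version can zoom into the wrong region and stop at a local optimum.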
import matplotlib.pyplot as plt

# Model: a linear function with a single global weight w
def forward(x):
    return x * w

# Compute the mean squared error (MSE) over the whole dataset
def cost(xs, ys):
    total = 0
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        total += (y_pred - y) ** 2
    return total / len(xs)

# Compute the gradient of the MSE with respect to w
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

# Dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Initial weight and learning rate
w = 1.0
r = 0.01

print('Predict (before training)', 4, forward(4))

epoch_list = []  # epoch indices
cost_list = []   # MSE at each epoch

for epoch in range(100):
    cost_val = cost(x_data, y_data)      # cost before this epoch's update
    grad_val = gradient(x_data, y_data)
    w = w - r * grad_val                 # gradient descent step
    # Note: w is printed after the update, cost_val was measured before it
    print('Epoch:', epoch, 'w=', w, 'cost=', round(cost_val, 3))
    epoch_list.append(epoch)
    cost_list.append(cost_val)

print('Predict (after training)', 4, forward(4))

# Plot the cost curve
plt.plot(epoch_list, cost_list)
plt.xlabel('Epoch')
plt.ylabel('Cost')
plt.show()
Output:
Predict (before training) 4 4.0
Epoch: 0 w= 1.0933333333333333 cost= 4.667
Epoch: 1 w= 1.1779555555555554 cost= 3.836
Epoch: 2 w= 1.2546797037037036 cost= 3.154
Epoch: 3 w= 1.3242429313580246 cost= 2.592
Epoch: 4 w= 1.3873135910979424 cost= 2.131
Epoch: 5 w= 1.4444976559288012 cost= 1.752
Epoch: 6 w= 1.4963445413754464 cost= 1.44
Epoch: 7 w= 1.5433523841804047 cost= 1.184
Epoch: 8 w= 1.5859728283235668 cost= 0.973
Epoch: 9 w= 1.6246153643467005 cost= 0.8
Epoch: 10 w= 1.659651263674342 cost= 0.658
Epoch: 11 w= 1.6914171457314033 cost= 0.541
Epoch: 12 w= 1.7202182121298057 cost= 0.444
Epoch: 13 w= 1.7463311789976905 cost= 0.365
Epoch: 14 w= 1.7700069356245727 cost= 0.3
Epoch: 15 w= 1.7914729549662791 cost= 0.247
Epoch: 16 w= 1.8109354791694263 cost= 0.203
Epoch: 17 w= 1.8285815011136133 cost= 0.167
Epoch: 18 w= 1.8445805610096762 cost= 0.137
Epoch: 19 w= 1.8590863753154396 cost= 0.113
Epoch: 20 w= 1.872238313619332 cost= 0.093
Epoch: 21 w= 1.8841627376815275 cost= 0.076
Epoch: 22 w= 1.8949742154979183 cost= 0.063
Epoch: 23 w= 1.904776622051446 cost= 0.051
Epoch: 24 w= 1.9136641373266443 cost= 0.042
Epoch: 25 w= 1.9217221511761575 cost= 0.035
Epoch: 26 w= 1.9290280837330496 cost= 0.029
Epoch: 27 w= 1.9356521292512983 cost= 0.024
Epoch: 28 w= 1.9416579305211772 cost= 0.019
Epoch: 29 w= 1.9471031903392007 cost= 0.016
Epoch: 30 w= 1.952040225907542 cost= 0.013
Epoch: 31 w= 1.9565164714895047 cost= 0.011
Epoch: 32 w= 1.9605749341504843 cost= 0.009
Epoch: 33 w= 1.9642546069631057 cost= 0.007
Epoch: 34 w= 1.9675908436465492 cost= 0.006
Epoch: 35 w= 1.970615698239538 cost= 0.005
Epoch: 36 w= 1.9733582330705144 cost= 0.004
Epoch: 37 w= 1.975844797983933 cost= 0.003
Epoch: 38 w= 1.9780992835054327 cost= 0.003
Epoch: 39 w= 1.980143350378259 cost= 0.002
Epoch: 40 w= 1.9819966376762883 cost= 0.002
Epoch: 41 w= 1.983676951493168 cost= 0.002
Epoch: 42 w= 1.9852004360204722 cost= 0.001
Epoch: 43 w= 1.9865817286585614 cost= 0.001
Epoch: 44 w= 1.987834100650429 cost= 0.001
Epoch: 45 w= 1.9889695845897222 cost= 0.001
Epoch: 46 w= 1.9899990900280147 cost= 0.001
Epoch: 47 w= 1.9909325082920666 cost= 0.0
Epoch: 48 w= 1.9917788075181404 cost= 0.0
Epoch: 49 w= 1.9925461188164473 cost= 0.0
Epoch: 50 w= 1.9932418143935788 cost= 0.0
Epoch: 51 w= 1.9938725783835114 cost= 0.0
Epoch: 52 w= 1.994444471067717 cost= 0.0
Epoch: 53 w= 1.9949629871013967 cost= 0.0
Epoch: 54 w= 1.9954331083052663 cost= 0.0
Epoch: 55 w= 1.9958593515301082 cost= 0.0
Epoch: 56 w= 1.9962458120539648 cost= 0.0
Epoch: 57 w= 1.9965962029289281 cost= 0.0
Epoch: 58 w= 1.9969138906555615 cost= 0.0
Epoch: 59 w= 1.997201927527709 cost= 0.0
Epoch: 60 w= 1.9974630809584561 cost= 0.0
Epoch: 61 w= 1.9976998600690001 cost= 0.0
Epoch: 62 w= 1.9979145397958935 cost= 0.0
Epoch: 63 w= 1.9981091827482769 cost= 0.0
Epoch: 64 w= 1.9982856590251044 cost= 0.0
Epoch: 65 w= 1.9984456641827613 cost= 0.0
Epoch: 66 w= 1.9985907355257035 cost= 0.0
Epoch: 67 w= 1.9987222668766378 cost= 0.0
Epoch: 68 w= 1.9988415219681517 cost= 0.0
Epoch: 69 w= 1.9989496465844576 cost= 0.0
Epoch: 70 w= 1.9990476795699081 cost= 0.0
Epoch: 71 w= 1.9991365628100501 cost= 0.0
Epoch: 72 w= 1.999217150281112 cost= 0.0
Epoch: 73 w= 1.999290216254875 cost= 0.0
Epoch: 74 w= 1.9993564627377531 cost= 0.0
Epoch: 75 w= 1.9994165262155628 cost= 0.0
Epoch: 76 w= 1.999470983768777 cost= 0.0
Epoch: 77 w= 1.9995203586170245 cost= 0.0
Epoch: 78 w= 1.9995651251461022 cost= 0.0
Epoch: 79 w= 1.9996057134657994 cost= 0.0
Epoch: 80 w= 1.9996425135423248 cost= 0.0
Epoch: 81 w= 1.999675878945041 cost= 0.0
Epoch: 82 w= 1.999706130243504 cost= 0.0
Epoch: 83 w= 1.9997335580874436 cost= 0.0
Epoch: 84 w= 1.9997584259992822 cost= 0.0
Epoch: 85 w= 1.9997809729060159 cost= 0.0
Epoch: 86 w= 1.9998014154347876 cost= 0.0
Epoch: 87 w= 1.9998199499942075 cost= 0.0
Epoch: 88 w= 1.9998367546614149 cost= 0.0
Epoch: 89 w= 1.9998519908930161 cost= 0.0
Epoch: 90 w= 1.9998658050763347 cost= 0.0
Epoch: 91 w= 1.9998783299358769 cost= 0.0
Epoch: 92 w= 1.9998896858085284 cost= 0.0
Epoch: 93 w= 1.9998999817997325 cost= 0.0
Epoch: 94 w= 1.9999093168317574 cost= 0.0
Epoch: 95 w= 1.9999177805941268 cost= 0.0
Epoch: 96 w= 1.9999254544053418 cost= 0.0
Epoch: 97 w= 1.9999324119941766 cost= 0.0
Epoch: 98 w= 1.9999387202080534 cost= 0.0
Epoch: 99 w= 1.9999444396553017 cost= 0.0
Predict (after training) 4 7.999777758621207
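The expression inside `gradient` in the script above comes from differentiating the MSE cost with respect to `w`. A short derivation, using N for the number of samples (a symbol introduced here, not in the original notes):

```latex
\begin{aligned}
\mathrm{cost}(w) &= \frac{1}{N}\sum_{n=1}^{N}\left(x_n w - y_n\right)^2 \\
\frac{\partial\,\mathrm{cost}}{\partial w} &= \frac{1}{N}\sum_{n=1}^{N} 2\,x_n\left(x_n w - y_n\right) \\
w &\leftarrow w - r\,\frac{\partial\,\mathrm{cost}}{\partial w}
\end{aligned}
```

For the dataset above, the sum of x_n^2 is 14 and the sum of x_n y_n is 28, so the gradient is (2/3)(14w - 28) = (28/3)(w - 2). At w = 1 this is about -9.333, and one step with r = 0.01 gives w ≈ 1.0933, which matches the Epoch 0 line of the output.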
If the plotted curve is not smooth, plot an exponentially weighted moving average of the cost instead.
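A minimal sketch of such smoothing follows. The function name `ewma`, the default `beta=0.9`, the bias-correction step, and the `noisy` values are all assumptions made for illustration; none of them come from the original notes.

```python
def ewma(values, beta=0.9):
    """Exponentially weighted moving average with bias correction."""
    smoothed = []
    avg = 0.0
    for t, v in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * v       # blend the old average with the new value
        smoothed.append(avg / (1 - beta ** t))  # correct the bias towards 0 in early steps
    return smoothed

noisy = [4.0, 5.0, 3.0, 3.5, 2.0, 2.8, 1.5, 1.9, 1.0, 1.2]  # made-up cost values
smooth = ewma(noisy)
print(smooth)
```

To use it with a training script, one would pass the recorded cost list through `ewma` before plotting, for example `plt.plot(epoch_list, ewma(cost_list))`. A larger `beta` smooths more but reacts more slowly to real changes in the cost.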
If training fails (the cost grows instead of shrinking), the learning rate may be set too large.
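The effect of the learning rate can be checked directly on the three-sample dataset used in these notes. The sketch below is illustrative: the helper `train` and the three learning rates are assumptions. For this data the gradient is (28/3)(w - 2), so each step multiplies the distance from w = 2 by (1 - 28r/3), and the iteration diverges once r exceeds 3/14 (about 0.214).

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def train(r, epochs=20, w=1.0):
    """Run batch gradient descent on y_pred = x * w and return the final w."""
    for _ in range(epochs):
        grad = sum(2 * x * (x * w - y) for x, y in zip(x_data, y_data)) / len(x_data)
        w -= r * grad
    return w

for r in (0.01, 0.1, 0.3):
    w_final = train(r)
    print('r =', r, 'final w =', w_final, 'distance from 2:', abs(w_final - 2.0))
```

With r = 0.3 the weight overshoots further on every step, which is one way a training run can fail.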
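Stochastic gradient descent (SGD) is used more often: pick one sample at random, differentiate that sample's loss with respect to the weight, and update the weight.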
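import matplotlib.pyplot as plt

# Model: a linear function with a single global weight w
def forward(x):
    return x * w

# Squared error of a single sample
def loss(x, y):
    return (forward(x) - y) ** 2

# Gradient of a single sample's loss with respect to w
def gradient(x, y):
    return 2 * x * (forward(x) - y)

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = 1.0   # initial weight
r = 0.01  # learning rate

print('Predict (before training): ', 4, forward(4))

epoch_list = []
loss_list = []

for epoch in range(100):
    # Samples are visited in their stored order here, not drawn at random
    for x, y in zip(x_data, y_data):
        grad = gradient(x, y)
        w -= r * grad    # update after every single sample
        l = loss(x, y)   # loss of this sample after the update
        # print('\tgrad:', x, y, grad)
    # l holds the loss of the last sample of this epoch
    print('Progress:', epoch, 'w=', w, 'loss=', round(l, 6))
    epoch_list.append(epoch)
    loss_list.append(l)

print('Predict (after training): ', 4, forward(4))

plt.plot(epoch_list, loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()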
Output:
Predict (before training): 4 4.0
Progress: 0 w= 1.260688 loss= 4.91924
Progress: 1 w= 1.453417766656 loss= 2.688769
Progress: 2 w= 1.5959051959019805 loss= 1.469633
Progress: 3 w= 1.701247862192685 loss= 0.803276
Progress: 4 w= 1.7791289594933983 loss= 0.439056
Progress: 5 w= 1.836707389300983 loss= 0.23998
Progress: 6 w= 1.8792758133988885 loss= 0.131169
Progress: 7 w= 1.910747160155559 loss= 0.071695
Progress: 8 w= 1.9340143044689266 loss= 0.039187
Progress: 9 w= 1.9512159834655312 loss= 0.021419
Progress: 10 w= 1.9639333911678687 loss= 0.011707
Progress: 11 w= 1.9733355232910992 loss= 0.006399
Progress: 12 w= 1.9802866323953892 loss= 0.003498
Progress: 13 w= 1.9854256707695 loss= 0.001912
Progress: 14 w= 1.9892250235079405 loss= 0.001045
Progress: 15 w= 1.9920339305797026 loss= 0.000571
Progress: 16 w= 1.994110589284741 loss= 0.000312
Progress: 17 w= 1.9956458879852805 loss= 0.000171
Progress: 18 w= 1.9967809527381737 loss= 9.3e-05
Progress: 19 w= 1.9976201197307648 loss= 5.1e-05
Progress: 20 w= 1.998240525958391 loss= 2.8e-05
Progress: 21 w= 1.99869919972735 loss= 1.5e-05
Progress: 22 w= 1.9990383027488265 loss= 8e-06
Progress: 23 w= 1.9992890056818404 loss= 5e-06
Progress: 24 w= 1.999474353368653 loss= 2e-06
Progress: 25 w= 1.9996113831376856 loss= 1e-06
Progress: 26 w= 1.9997126908902887 loss= 1e-06
Progress: 27 w= 1.9997875889274812 loss= 0.0
Progress: 28 w= 1.9998429619451539 loss= 0.0
Progress: 29 w= 1.9998838998815958 loss= 0.0
Progress: 30 w= 1.9999141657892625 loss= 0.0
Progress: 31 w= 1.9999365417379913 loss= 0.0
Progress: 32 w= 1.9999530845453979 loss= 0.0
Progress: 33 w= 1.9999653148414271 loss= 0.0
Progress: 34 w= 1.999974356846045 loss= 0.0
Progress: 35 w= 1.9999810417085633 loss= 0.0
Progress: 36 w= 1.9999859839076413 loss= 0.0
Progress: 37 w= 1.9999896377347262 loss= 0.0
Progress: 38 w= 1.999992339052936 loss= 0.0
Progress: 39 w= 1.9999943361699042 loss= 0.0
Progress: 40 w= 1.9999958126624442 loss= 0.0
Progress: 41 w= 1.999996904251097 loss= 0.0
Progress: 42 w= 1.999997711275687 loss= 0.0
Progress: 43 w= 1.9999983079186507 loss= 0.0
Progress: 44 w= 1.9999987490239537 loss= 0.0
Progress: 45 w= 1.9999990751383971 loss= 0.0
Progress: 46 w= 1.9999993162387186 loss= 0.0
Progress: 47 w= 1.9999994944870796 loss= 0.0
Progress: 48 w= 1.9999996262682318 loss= 0.0
Progress: 49 w= 1.999999723695619 loss= 0.0
Progress: 50 w= 1.9999997957248556 loss= 0.0
Progress: 51 w= 1.9999998489769344 loss= 0.0
Progress: 52 w= 1.9999998883468353 loss= 0.0
Progress: 53 w= 1.9999999174534755 loss= 0.0
Progress: 54 w= 1.999999938972364 loss= 0.0
Progress: 55 w= 1.9999999548815364 loss= 0.0
Progress: 56 w= 1.9999999666433785 loss= 0.0
Progress: 57 w= 1.9999999753390494 loss= 0.0
Progress: 58 w= 1.9999999817678633 loss= 0.0
Progress: 59 w= 1.9999999865207625 loss= 0.0
Progress: 60 w= 1.999999990034638 loss= 0.0
Progress: 61 w= 1.9999999926324883 loss= 0.0
Progress: 62 w= 1.99999999455311 loss= 0.0
Progress: 63 w= 1.9999999959730488 loss= 0.0
Progress: 64 w= 1.9999999970228268 loss= 0.0
Progress: 65 w= 1.9999999977989402 loss= 0.0
Progress: 66 w= 1.9999999983727301 loss= 0.0
Progress: 67 w= 1.9999999987969397 loss= 0.0
Progress: 68 w= 1.999999999110563 loss= 0.0
Progress: 69 w= 1.9999999993424284 loss= 0.0
Progress: 70 w= 1.9999999995138495 loss= 0.0
Progress: 71 w= 1.9999999996405833 loss= 0.0
Progress: 72 w= 1.999999999734279 loss= 0.0
Progress: 73 w= 1.9999999998035491 loss= 0.0
Progress: 74 w= 1.9999999998547615 loss= 0.0
Progress: 75 w= 1.9999999998926234 loss= 0.0
Progress: 76 w= 1.9999999999206153 loss= 0.0
Progress: 77 w= 1.9999999999413098 loss= 0.0
Progress: 78 w= 1.9999999999566096 loss= 0.0
Progress: 79 w= 1.9999999999679208 loss= 0.0
Progress: 80 w= 1.9999999999762834 loss= 0.0
Progress: 81 w= 1.999999999982466 loss= 0.0
Progress: 82 w= 1.9999999999870368 loss= 0.0
Progress: 83 w= 1.999999999990416 loss= 0.0
Progress: 84 w= 1.9999999999929146 loss= 0.0
Progress: 85 w= 1.9999999999947617 loss= 0.0
Progress: 86 w= 1.9999999999961273 loss= 0.0
Progress: 87 w= 1.999999999997137 loss= 0.0
Progress: 88 w= 1.9999999999978835 loss= 0.0
Progress: 89 w= 1.9999999999984353 loss= 0.0
Progress: 90 w= 1.9999999999988431 loss= 0.0
Progress: 91 w= 1.9999999999991447 loss= 0.0
Progress: 92 w= 1.9999999999993676 loss= 0.0
Progress: 93 w= 1.9999999999995324 loss= 0.0
Progress: 94 w= 1.9999999999996543 loss= 0.0
Progress: 95 w= 1.9999999999997444 loss= 0.0
Progress: 96 w= 1.999999999999811 loss= 0.0
Progress: 97 w= 1.9999999999998603 loss= 0.0
Progress: 98 w= 1.9999999999998967 loss= 0.0
Progress: 99 w= 1.9999999999999236 loss= 0.0
Predict (after training): 4 7.9999999999996945
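Problems:
Gradient descent: the per-sample gradients can be computed in parallel, so it is the most time-efficient, but the resulting model tends to perform worse.
Stochastic gradient descent: the gradients cannot be computed in parallel because each update depends on the previous one; the resulting model tends to perform better, but the time cost is higher.
Solution:
A compromise: mini-batch stochastic gradient descent (often just called "batch"), which updates the weight using a small batch of samples at a time.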
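A minimal sketch of the mini-batch compromise on the same three-sample dataset. The function name `minibatch_sgd`, the batch size of 2, the per-epoch shuffling, and the fixed seed are assumptions chosen for illustration, not part of the original notes.

```python
import random

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def minibatch_sgd(batch_size=2, r=0.01, epochs=100, seed=0):
    """Train y_pred = x * w with mini-batch SGD and return the final w."""
    rng = random.Random(seed)  # fixed seed so runs are repeatable
    w = 1.0
    samples = list(zip(x_data, y_data))
    for _ in range(epochs):
        rng.shuffle(samples)   # new random sample order every epoch
        for start in range(0, len(samples), batch_size):
            batch = samples[start:start + batch_size]
            # Average the gradient over the batch; these terms are independent,
            # so within one batch they could be computed in parallel.
            grad = sum(2 * x * (x * w - y) for x, y in batch) / len(batch)
            w -= r * grad
    return w

w_final = minibatch_sgd()
print('w =', w_final, 'prediction for x = 4:', 4 * w_final)
```

Setting `batch_size=1` recovers the per-sample updates of stochastic gradient descent, and setting it to the dataset size recovers plain gradient descent, so the batch size is the knob that trades parallelism against update frequency.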