总结
requires_grad=True with torch.enable_grad():# 权重衰退是广泛应用的正则化技术
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
人工数据集
y
=
0.05
+
∑
i
=
1
d
0.01
x
i
+
ϵ
,
w
h
e
r
e
ϵ
∼
η
(
0
,
0.0
1
2
)
,
ϵ
是偏差
y=0.05+\sum\limits^d_{i=1}0.01x_i+\epsilon, where \epsilon\sim \eta(0,0.01^2), \epsilon是偏差
y=0.05+i=1∑d0.01xi+ϵ,whereϵ∼η(0,0.012),ϵ是偏差
# 数据
n_train,n_test,num_inputs,batch_size=20,100,200,5
# true_b=0.05
true_w,true_b=torch.ones((num_inputs,1))*0.01,0.05
# 20个训练数据样本,模型容量大(num_inputs=100)+数据量小=容易发生过拟合
train_data=d2l.synthetic_data(true_w,true_b,n_train)
train_iter=d2l.load_array(train_data,batch_size)
# 5个测试数据样本,实际上是验证集
test_data=d2l.synthetic_data(true_w,true_b,n_test)
test_iter=d2l.load_array(test_data,batch_size,is_train=False)
# 参数初始化
def init_params():
# !requires_grad=True作用,需要对这个参数w计算梯度
w=torch.normal(0,1,size=(num_inputs,1),requires_grad=True)
# requires_grad=True作用,需要对这个参数w计算梯度
b=torch.zeros(1,requires_grad=True)
return [w,b]
# 惩罚项
def l2_penalty(w):
return torch.sum(w.pow(2))/2
# 训练
def train(lambd):
w,b=init_params()
# 模型和损失函数
net,loss=lambda X:d2l.linreg(X,w,b),d2l.squared_loss
# 超参数
num_epochs,lr=100,0.003
# 展示 xlabel, x轴代表什么;ylabel,y轴代表什么;yscale,y轴缩放类型;xlim,x轴范围限制;legend,铭文,图例
animator=d2l.Animator(xlabel='epochs',ylabel='loss',
yscale='log',
xlim=[5, num_epochs],
legend=['train', 'test'])
# 开始训练
for epoch in range(num_epochs):
for X,y in train_iter:
# 上下文管理器
# 先调用with后面的`troch.enable_grad()`的`__enter__()`方法,执行完with内部再调用troch.enable_grad()`的` __exit__()`
with torch.enable_grad():
# 增加了L2范数惩罚项。广播机制,w被复制batchs_size次
l=loss(net(X),y)+lambd*l2_penalty(w)
# sum()不影响梯度,因为是梯度是求偏导
# !书写损失函数时就需要梯度,后向传播之前就关闭梯度
l.sum().backward()
d2l.sgd([w,b],lr,batch_size)
# 运行完5个epoch之后展示1次
if(epoch+1)%5==0:
animator.add(epoch + 1,
(d2l.evaluate_loss(net, train_iter, loss),
d2l.evaluate_loss(net, test_iter, loss)))
# 均方损失
print('w的L2范数:',torch.norm(w).item())
# 无正则
train(lambd=0)
# 有正则项且lambda=3
train(lambd=3)
# 有正则项且lambda=10
train(lambd=10)
w的L2范数: 0.02175123244524002
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zv80bQYp-1662210118936)(output_6_1.svg)]](https://1000bd.com/contentImg/2023/10/31/052758350.png)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ndb2YX0H-1662210118939)(output_6_2.svg)]](https://1000bd.com/contentImg/2023/10/31/052758368.png)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g6UcJPbB-1662210118940)(output_6_3.svg)]](https://1000bd.com/contentImg/2023/10/31/052758367.png)