• pytorch优化器设置


    深度学习训练过程中学习率的大小十分重要。学习率过低会导致学习太慢,学习率过高会导致难以收敛。通常情况下,初始学习率会比较大,后来逐渐缩小学习率。

    通常情况下模型优化器设置

    首先定义两层全连接层模型

    1. import torch
    2. from torch import nn
    3. class Net(nn.Module):
    4. def __init__(self):
    5. super(Net, self).__init__()
    6. self.layer1 = nn.Linear(10, 2)
    7. self.layer2 = nn.Linear(2, 10)
    8. def forward(self, input):
    9. return self.layer2(self.layer1(input))

     神经网络的执行步骤。首先神经网络进过前向传播,这是神经网络框架会搭建好计算图(这里会保存操作和对应参与计算的张量,因为在根据计算图计算梯度时需要这些信息)。然后是误差反向传播,loss.backward() ,这时会计算梯度信息。最后根据梯度信息,更新参数。

    1. loss.backward()
    2. optimizer.step()
    3. optimizer.zero_grad()

     optimizer.zero_grad() 是将这一轮的梯度清零,防止影响下一轮参数的更新。这里曾问过面试的问题:什么时候不使用这一步进行清零。

    1. model = Net()
    2. # 只传入想要训练层的参数。其他未传入的参数不参与更新
    3. optimizer_Adam = torch.optim.Adam(model.parameters(), lr=0.1)

    model.parameters()会返回模型的所有参数 

    只训练模型的部分参数

    也就是说只传入模型待优化的参数,为传入的参数不参与更新。

    1. model = Net()
    2. # 只传入待优化的参数
    3. optimizer_Adam = torch.optim.Adam(model.layer1.parameters(), lr=0.1)

     不同部分设置不同的学习率

    1. params_dict = [{'params': model.layer1.parameters(), 'lr': 0.01},
    2. {'params': model.layer2.parameters(), 'lr': 0.001}]
    3. optimizer = torch.optim.Adam(params_dict)

    动态修改学习率

    优化器的param_group属性

    1. -param_groups
    2. -0(dict) # 第一组参数
    3. params: # 维护要更新的参数
    4. lr: # 该组参数的学习率
    5. betas:
    6. eps: # 该组参数的学习率最小值
    7. weight_decay: # 该组参数的权重衰减系数
    8. amsgrad:
    9. -1(dict) # 第二组参数
    10. -2(dict) # 第三组参数

     parm_group是一个列表,其中每个元素都是一个字典

    1. model = Net() # 生成网络
    2. optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 生成优化器
    3. for epoch in range(100): # 假设迭代100个epoch
    4. if epoch % 5 == 0: # 每迭代5次,更新一次学习率
    5. for params in optimizer.param_groups:
    6. # 遍历Optimizer中的每一组参数,将该组参数的学习率 * 0.9
    7. params['lr'] *= 0.9

  • 相关阅读:
    JavaWeb中,web应用的上下文路径解读
    celery介绍与使用
    gitHub不能用密码推送了,必须要使用令牌
    在 Next.js 中实现用户授权
    图、图的遍历、最小生成树、最短路径
    抄写Linux源码(Day17:你的键盘是什么时候生效的?)
    【单元测试】--维护和改进单元测试
    Vue3中runtime-dom的实现-详细步骤
    Flink cdc 2.3.0 日前发布,支持众多新特性
    选择文档管理系统时的 5 个关键考虑因素
  • 原文地址:https://blog.csdn.net/qq_40107571/article/details/126014057