• 优化算法 - Adadelta


    Adadelta

    Adadelta是AdaGrad的另一种变体,主要区别在于前者减少了学习率适应坐标的数量。此外,广义上Adadelta被称为没有学习率,因为它使用变化量作为未来变化的校准

    1 - Adadelta算法

    2 - 代码实现

    Adadelta需要为每个变量维护两个状态变量,即 s t 和 Δ x t s_t和\Delta x_t stΔxt

    %matplotlib inline
    import torch
    from d2l import torch as d2l
    
    • 1
    • 2
    • 3
    def init_adadelta_states(feature_dim):
        s_w,s_b = torch.zeros((feature_dim,1)),torch.zeros(1)
        delta_w, delta_b = torch.zeros((feature_dim, 1)), torch.zeros(1)
        return ((s_w,delta_w),(s_b,delta_b))
    
    def adadelta(params,states,hyperparams):
        rho,eps = hyperparams['rho'],1e-5
        for p,(s,delta) in zip(params,states):
            with torch.no_grad():
                #  In-placeupdatesvia[:]
                s[:] = rho * s + (1 - rho) * torch.square(p.grad)
                g = (torch.sqrt(delta + eps) / torch.sqrt(s + eps)) * p.grad
                p[:] -= g
                delta[:] = rho * delta + (1 - rho) * g * g
            p.grad.data.zero_()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    对于每次参数更新,选择ρ = 0.9相当于10个半衰期,由此我们得到

    data_iter,feature_dim = d2l.get_data_ch11(batch_size=10)
    d2l.train_ch11(adadelta,init_adadelta_states(feature_dim),{'rho':0.9},data_iter,feature_dim);
    
    • 1
    • 2
    loss: 0.243, 0.009 sec/epoch
    
    • 1

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AFNLhqPW-1663327838192)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209161924784.svg)]

    为了简洁实现,我们只需使用Trainer类中的adadelta算法

    trainer = torch.optim.Adadelta
    d2l.train_concise_ch11(trainer,{'rho':0.9},data_iter)
    
    • 1
    • 2
    loss: 0.243, 0.007 sec/epoch
    
    • 1

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-s9E9EN42-1663327838192)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209161924785.svg)]

    3 - 小结

    • Adadelta没有学习率参数。相反,它使用参数本身的学习率来调整学习率
    • Adadelta需要两个状态变量来存储梯度 的二阶导数和参数的变化
    • Adadelta使用泄露的平均值来保持对适当统计数据的运行估计
  • 相关阅读:
    APP自动化测试-9.Appium设备交互与模拟器控制
    并查集 rank 的优化
    操作系统基础知识1
    SAP 采购订单抬头屏幕增强(SMOD)
    antd form+upload把上传后的文件删除后,表单校验失效了
    睿趣科技:抖音店铺名字怎么更吸引人
    k8s 1.28安装
    python多线程
    JavaSE 第六章 面向对象基础-中(封装)
    Java比较两个日期的间隔天数(案例详解)
  • 原文地址:https://blog.csdn.net/mynameisgt/article/details/126896472