码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • Pytorch梯度累积实现


    前言

    主要用于解决显卡内存不足的问题。
    梯度累积可以使用单卡实现增大batchsize的效果

    梯度累积原理

    按顺序执行Mini-Batch,同时对梯度进行累积,累积的结果在最后一个Mini-Batch计算后求平均更新模型变量。
    a c c u m u l a t e d = ∑ i = 0 N g r a d i \color{green}accumulated=\sum_{i=0}^{N}grad_{i} accumulated=i=0∑N​gradi​

    梯度累积是一种训练神经网络的数据Sample样本按Batch拆分为几个小Batch的方式,然后按顺序计算。
    在不更新模型变量的时候,实际上是把原来的数据Batch分成几个小的Mini-Batch,每个step中使用的样本实际上是更小的数据集。
    在N个step内不更新变量,使所有Mini-Batch使用相同的模型变量来计算梯度,以确保计算出来得到相同的梯度和权重信息,算法上等价于使用原来没有切分的Batch size大小一样。即:
    θ i = θ i − 1 − l r ∗ ∑ i = 0 N g r a d i \color{green}\theta _{i}=\theta _{i-1}-lr*\sum_{i=0}^{N}grad_{i} θi​=θi−1​−lr∗i=0∑N​gradi​
    在这里插入图片描述

    代码实现

    不加梯度累加的代码

    for i, (images, labels) in enumerate(train_data):
        # 1. forwared 前向计算
        outputs = model(images)
        loss = criterion(outputs, labels)
    
        # 2. backward 反向传播计算梯度
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    加了梯度累加的代码

    # 梯度累加参数
    accumulation_steps = 4
    
    
    for i, (images, labels) in enumerate(train_data):
        # 1. forwared 前向计算
        outputs = model(imgaes)
        loss = criterion(outputs, labels)
    
        # 2.1 loss regularization loss正则化
        loss += loss / accumulation_steps
    
        # 2.2 backward propagation 反向传播计算梯度
        loss.backward()
    
        # 3. update parameters of net
        if ((i+1) % accumulation)==0:
            # optimizer the net
            optimizer.step()
            optimizer.zero_grad() # reset grdient
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    代码中设置accumulation_steps = 4,意思就是变相扩大batch_size四倍。因为代码中每隔4次迭代才清空梯度,更新参数。
    loss = loss/accumulation_steps,梯度累加了四次,那就要取平均除以4。同时,因为累计了4个batch,那学习率也应该扩大4倍,让更新的步子跨大点。
    参考博客:1、pytorch骚操作之梯度累加,变相增大batch size
    2、如何通透理解梯度累加

  • 相关阅读:
    51单片机数码管交通灯(51单片机实训项目)
    Java基础(十九)Map
    Dubbo 服务注册与启动源码解析
    FastDFS——从入门到入土(上)
    nginx实现灰度上线(InsCode AI 创作助手)
    springBoot Event实现异步消费机制
    人工智能的发展现状,AI将如何改变IT行业,哪些职业将最先失业
    MPLS VPN跨域C1方案 RR反射器
    六月集训(27) 图
    (02)Cartographer源码无死角解析-(27) 数据订阅、变换、排序、转发→总体复盘
  • 原文地址:https://blog.csdn.net/fcxgfdjy/article/details/133294760
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号