码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 有关PyTorch中Checkpoint的原理、实现和问题


    有关PyTorch中Checkpoint的原理、实现和问题

    一、动机

    ​ 由于复现某些论文中的代码时,使用正常的方法跑,显存不够。了解到这个方法是牺牲时间来降低显存,使用完之后,果然可以跑起来,而且显存降低了好多。那个代码至少30G显存才可能跑起来,使用完之后,不到9个G。

    ​ 写这个博客希望可以帮助到一些有需要的人。

    二、原理

    我们使用pytorch训练模型的时候主要有四部分消耗显存。

    • 模型参数
    • 模型参数的梯度
    • 优化器状态
    • 中间激活值

    模型的现存之所以那么大,其中原因之一就是计算梯度时,模型会把所有前向传播的中间激活值都保存下来,这非常消耗显存,这样的好处是,需要那个中间激活值时,可以直接用,就不需要再次计算,节省了时间。

    Checkpointing采取的策略是:保留一部分中间激活值,其余部分丢弃,如果用到的中间激活值没有的话,就重新计算,这样大大节省了显存,但是增加了时间。

    三、实现

    for cascade in self.cascades:
         if is_training:
            kspace_pred = checkpoint.checkpoint(cascade, x1, x2)
         else:
            kspace_pred = cascade(x1, x2)
            
    # cascade:网络
    # x1:网络的参数1
    # x2:网络的参数2
    

    上述是在训练的时候使用checkpoint技术,在验证和测试的时候不使用。

    checkpoint放在你进入网络,开始迭代的时候。

    四、问题

    如果,你使用的时候遇到下面这个警告。

    警告:UserWarning: None of the inputs have requires_grad=True.

    可能的解决办法之一:

    你把所有的 requires_grad设置为True。

    可能的解决办法之一:

    你在测试或者验证的时候也使用了checkpoint,因为测试的或者验证的时候,不需要梯度传播,也就引发了这个警告。

    你可以不用管,结果应该是一样的。

    如果你不想看到警告,你就设置个判断,测试和验证的时候不使用checkpoint,仅在训练的时候使用。

    参考文章

    • https://blog.csdn.net/Solo95/article/details/131606918?s
    • https://blog.csdn.net/Shirelle_/article/details/137868196
    • https://zhuanlan.zhihu.com/p/424512257
    • https://blog.csdn.net/P_LarT/article/details/122521212
  • 相关阅读:
    开箱即用的数据缓存服务|EMQX Cloud 影子服务应用场景解析
    选择题汇总1-2(括号里填的答案都是对的,不用管下面那个答案正确与错误,因为作者懒得删了)
    【Linux从入门到精通】通信 | 共享内存(System V)
    机器学习入门案例(2)之使用逻辑回归预测房子是否能被租出去
    8张图带你全面了解kafka的核心机制
    LiveGBS流媒体平台国标GB/T28181作为下级支持国标级联海康大华宇视华为等第三方国标平台支持对接政务公安内网国标视频平台
    shiro篇---开启常见的注解
    python DevOps
    如何系统地自学 Python
    源码级深度理解 Java SPI
  • 原文地址:https://blog.csdn.net/lihaiyuan_0324/article/details/139299374
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号