• 有关PyTorch中Checkpoint的原理、实现和问题


    有关PyTorch中Checkpoint的原理、实现和问题

    一、动机

    ​ 由于复现某些论文中的代码时,使用正常的方法跑,显存不够。了解到这个方法是牺牲时间来降低显存,使用完之后,果然可以跑起来,而且显存降低了好多。那个代码至少30G显存才可能跑起来,使用完之后,不到9个G。

    ​ 写这个博客希望可以帮助到一些有需要的人。

    二、原理

    我们使用pytorch训练模型的时候主要有四部分消耗显存。

    • 模型参数
    • 模型参数的梯度
    • 优化器状态
    • 中间激活值

    模型的现存之所以那么大,其中原因之一就是计算梯度时,模型会把所有前向传播的中间激活值都保存下来,这非常消耗显存,这样的好处是,需要那个中间激活值时,可以直接用,就不需要再次计算,节省了时间。

    Checkpointing采取的策略是:保留一部分中间激活值,其余部分丢弃,如果用到的中间激活值没有的话,就重新计算,这样大大节省了显存,但是增加了时间。

    三、实现

    for cascade in self.cascades:
         if is_training:
            kspace_pred = checkpoint.checkpoint(cascade, x1, x2)
         else:
            kspace_pred = cascade(x1, x2)
            
    # cascade:网络
    # x1:网络的参数1
    # x2:网络的参数2
    

    上述是在训练的时候使用checkpoint技术,在验证和测试的时候不使用。

    checkpoint放在你进入网络,开始迭代的时候。

    四、问题

    如果,你使用的时候遇到下面这个警告。

    警告:UserWarning: None of the inputs have requires_grad=True.

    可能的解决办法之一:

    你把所有的 requires_grad设置为True。

    可能的解决办法之一:

    你在测试或者验证的时候也使用了checkpoint,因为测试的或者验证的时候,不需要梯度传播,也就引发了这个警告。

    你可以不用管,结果应该是一样的。

    如果你不想看到警告,你就设置个判断,测试和验证的时候不使用checkpoint,仅在训练的时候使用。

    参考文章

    • https://blog.csdn.net/Solo95/article/details/131606918?s
    • https://blog.csdn.net/Shirelle_/article/details/137868196
    • https://zhuanlan.zhihu.com/p/424512257
    • https://blog.csdn.net/P_LarT/article/details/122521212
  • 相关阅读:
    【建议背诵】软考高项考试案例简答题汇总~
    【黑马-SpringCloud技术栈】【03】Eureka注册中心_Ribbon负载均衡
    【无标题】esp8266替代产品
    MFC-GetSystemFirmwareTable获取系统固件表
    python django 小程序点餐源码
    React+Vue相关插件使用的缺陷小合集
    【MMC/SD/SDIO】概述
    Spring的创建与使用
    java毕业设计房地产管理系统登录Mybatis+系统+数据库+调试部署
    C4 数据集基本信息速览
  • 原文地址:https://blog.csdn.net/lihaiyuan_0324/article/details/139299374