• 有关PyTorch中Checkpoint的原理、实现和问题


    有关PyTorch中Checkpoint的原理、实现和问题

    一、动机

    ​ 由于复现某些论文中的代码时,使用正常的方法跑,显存不够。了解到这个方法是牺牲时间来降低显存,使用完之后,果然可以跑起来,而且显存降低了好多。那个代码至少30G显存才可能跑起来,使用完之后,不到9个G。

    ​ 写这个博客希望可以帮助到一些有需要的人。

    二、原理

    我们使用pytorch训练模型的时候主要有四部分消耗显存。

    • 模型参数
    • 模型参数的梯度
    • 优化器状态
    • 中间激活值

    模型的现存之所以那么大,其中原因之一就是计算梯度时,模型会把所有前向传播的中间激活值都保存下来,这非常消耗显存,这样的好处是,需要那个中间激活值时,可以直接用,就不需要再次计算,节省了时间。

    Checkpointing采取的策略是:保留一部分中间激活值,其余部分丢弃,如果用到的中间激活值没有的话,就重新计算,这样大大节省了显存,但是增加了时间。

    三、实现

    for cascade in self.cascades:
         if is_training:
            kspace_pred = checkpoint.checkpoint(cascade, x1, x2)
         else:
            kspace_pred = cascade(x1, x2)
            
    # cascade:网络
    # x1:网络的参数1
    # x2:网络的参数2
    

    上述是在训练的时候使用checkpoint技术,在验证和测试的时候不使用。

    checkpoint放在你进入网络,开始迭代的时候。

    四、问题

    如果,你使用的时候遇到下面这个警告。

    警告:UserWarning: None of the inputs have requires_grad=True.

    可能的解决办法之一:

    你把所有的 requires_grad设置为True。

    可能的解决办法之一:

    你在测试或者验证的时候也使用了checkpoint,因为测试的或者验证的时候,不需要梯度传播,也就引发了这个警告。

    你可以不用管,结果应该是一样的。

    如果你不想看到警告,你就设置个判断,测试和验证的时候不使用checkpoint,仅在训练的时候使用。

    参考文章

    • https://blog.csdn.net/Solo95/article/details/131606918?s
    • https://blog.csdn.net/Shirelle_/article/details/137868196
    • https://zhuanlan.zhihu.com/p/424512257
    • https://blog.csdn.net/P_LarT/article/details/122521212
  • 相关阅读:
    高济健康:数字化科技创新与新零售碰撞 助推医疗产业优化升级
    测试/开发程序员的成长路线,全局思考问题的问题......
    接口自动化测试难点:数据库验证解决方案
    尚好房 04_服务拆分
    YoloV8训练自己的模型 && Pycharm Remote Development
    Day31|贪心算法1
    机器学习 sklearn数据集
    小鼠血清白蛋白包裹四氧化三铁纳米粒
    electron-vue初始化项目到打包运行
    docker启动mysql实例之后,docker ps命令查询不到
  • 原文地址:https://blog.csdn.net/lihaiyuan_0324/article/details/139299374