由于复现某些论文中的代码时,使用正常的方法跑,显存不够。了解到这个方法是牺牲时间来降低显存,使用完之后,果然可以跑起来,而且显存降低了好多。那个代码至少30G显存才可能跑起来,使用完之后,不到9个G。
写这个博客希望可以帮助到一些有需要的人。
我们使用pytorch训练模型的时候主要有四部分消耗显存。
模型的现存之所以那么大,其中原因之一就是计算梯度时,模型会把所有前向传播的中间激活值都保存下来,这非常消耗显存,这样的好处是,需要那个中间激活值时,可以直接用,就不需要再次计算,节省了时间。
Checkpointing采取的策略是:保留一部分中间激活值,其余部分丢弃,如果用到的中间激活值没有的话,就重新计算,这样大大节省了显存,但是增加了时间。
for cascade in self.cascades:
if is_training:
kspace_pred = checkpoint.checkpoint(cascade, x1, x2)
else:
kspace_pred = cascade(x1, x2)
# cascade:网络
# x1:网络的参数1
# x2:网络的参数2
上述是在训练的时候使用checkpoint技术,在验证和测试的时候不使用。
checkpoint放在你进入网络,开始迭代的时候。
如果,你使用的时候遇到下面这个警告。
警告:UserWarning: None of the inputs have requires_grad=True.
可能的解决办法之一:
你把所有的 requires_grad设置为True。
可能的解决办法之一:
你在测试或者验证的时候也使用了checkpoint,因为测试的或者验证的时候,不需要梯度传播,也就引发了这个警告。
你可以不用管,结果应该是一样的。
如果你不想看到警告,你就设置个判断,测试和验证的时候不使用checkpoint,仅在训练的时候使用。