今天跑一个模型的时候,需要加载部分预训练模型的参数,这期间遇到使用torch.load 忽略了 map_location参数 默认gpu,这导致这个变量分配的显存 不释放 然后占用大量资源 gpu资源不能很好的利用。
比如我们一般我们会使用下面方式进行加载预训练参数 到 自身写的模型中:
from transformers import RobertaForMultipleChoice
import torch
model = RobertaForMultipleChoice.from_pretrained("roberta-large")
pretrained_model = torch.load("./checkpoints/txt_matching_e1.pth").roberta
pretrained_dict = pretrained_model.state_dict()
model_dict = model.roberta.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} #去除一些不需要的参数
model_dict.update(pretrained_dict)
model.roberta.load_state_dict(model_dict)
1. 当我们没有使用参数时候 load 默认使用了一块显卡然后报错
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 3; 10.76 GiB total capacity; 350.54 MiB already allocated; 21.81 MiB free; 356.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
torch load 之前gpu使用
torch load 之后 outof memory 了 并且也不释放
2. 当我们没有使用参数时候 load 默认使用了一块显卡然后报错
当我试试指定显卡 gpu会使用2841
pretrained_model = torch.load(“./checkpoints/txt_matching_e1.pth”,map_location=‘cuda:0’).roberta!
(model 直接cuda 的gpu 占用情况)
然后把这里面参数给model,并且model也是用cuda0 然后gpu使用4193
你可能会想model是不是model load预训练参数之后 就这么大了 那么load 和 load 参数后model 用不同gpu看看。
原理:cuda的内存管理机制
参考解释博客:https://blog.csdn.net/qq_43827595/article/details/115722953
解决方案:
1. 不占用显存的使用方法,使用cpu 然后在del 用gc释放内存
model = RobertaForMultipleChoice.from_pretrained("roberta-large")
pretrained_model = torch.load("./checkpoints/txt_matching_e1.pth",map_location='cpu').roberta
pretrained_dict = pretrained_model.state_dict()
model_dict = model.roberta.state_dict()
model_dict.update(pretrained_dict)
model.roberta.load_state_dict(model_dict)
del pretrained_model
import gc
gc.collect()
2. 合理使用, torch.cuda.empty_cache()
这个需要了解一下python的内存管理,引用机制。
比如我pretrain_model 给model直接加载参数,model和pretrain_model 都在cuda:0上,使用torch.cuda.empty_cache() 不能释放pretrain_model 的显存。
当 我把model 放到 cuda:1上(本来在cuda:0),这时候用torch.cuda.empty_cache() 可以释放。