torch.load()的作用:从文件加载用torch.save()保存的对象。
格式:torch.load — PyTorch 1.12 documentation
torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
参数解释:
常用使用方式:torch.load — PyTorch 1.12 documentation
- # 常用根据设备,加载Tensor
- >>> torch.load('modelparameters.pth', map_location = device)
-
- # 默认加载方式,使用cpu加载cpu训练得出的模型或者用gpu调用gpu训练的模型
- >>> torch.load('tensors.pt')
-
- # Load all tensors onto the CPU
- # ♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥♥将全部Tensor全部加载到cpu上
- >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
-
- # Load all tensors onto the CPU, using a function
- # 使用一个函数将所有的Tensor加载到CPU上
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
- # Load all tensors onto GPU 1
-
- # 加载全部Tensor到GPU 1上
- >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
-
- # Map tensors from GPU 1 to GPU 0
- #将张量从GPU 1映射到GPU 0
- >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
-
- # Load tensor from io.BytesIO object
- # 从io加载张量。BytesIO对象
- >>> with open('tensor.pt', 'rb') as f:
- ... buffer = io.BytesIO(f.read())
- >>> torch.load(buffer)
-
- # Load a module with 'ascii' encoding for unpickling
- # 加载一个带有'ascii'编码的模块用于反pickle
- >>> torch.load('module.pt', encoding='ascii')