在PyTorch中,模型训练过程中保存模型通常涉及以下几个步骤:
保存整个模型:
使用 torch.save
函数,你可以保存整个模型,包括模型的结构和参数。
torch.save(model, 'model.pth')
加载模型时,使用 torch.load
函数。
model = torch.load('model.pth')
保存模型的参数:
这种方法通常更受欢迎,因为它只保存模型的参数,不保存模型的结构。这样,模型文件会比较小,并且在加载模型时可以更加灵活。
torch.save(model.state_dict(), 'model_params.pth')
加载模型时,首先创建模型的实例,然后加载参数。
model = ModelClass() # replace ModelClass with your model's class name
model.load_state_dict(torch.load('model_params.pth'))
保存训练的检查点:
在训练过程中,除了保存模型或模型的参数,通常还会保存其他关键信息,例如优化器的状态、当前的epoch、最佳准确率等。这样,如果训练被中断,可以从检查点继续训练,而不是从头开始。
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
# ... any other relevant information
}
torch.save(checkpoint, 'checkpoint.pth')
加载检查点时:
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
在训练时定期保存模型:
通常,我们会在每个epoch结束时或在验证准确率提高时保存模型。这样,如果训练过程中出现任何问题,我们可以从最近的检查点恢复。
在训练循环中,你可能会在每个 epoch 结束时或在模型在验证集上达到新的最佳性能时保存检查点:
# 假设以下变量已经定义:
# model: 你的模型
# optimizer: 你使用的优化器
# epoch: 当前的epoch
# loss: 最近的loss值
# best_accuracy: 迄今为止在验证集上的最佳准确率
# 在每个 epoch 结束时或在验证准确率提高时:
if current_accuracy > best_accuracy: # current_accuracy是这个epoch在验证集上的准确率
best_accuracy = current_accuracy
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
'best_accuracy': best_accuracy
}
torch.save(checkpoint, 'best_checkpoint.pth')
当你希望从检查点继续训练或评估模型时,可以使用以下代码来加载检查点:
# 假设以下变量已经定义:
# model: 你的模型 (需要先实例化)
# optimizer: 你使用的优化器 (需要先实例化)
checkpoint = torch.load('best_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
best_accuracy = checkpoint['best_accuracy']
# 如果继续训练,可以从上一个 epoch 开始
model.train()
这样,即使训练过程中断,你也可以从上次停止的地方继续,而不是重新开始。
torch.save(model.state_dict(), 'model_params.pth')
# Loading on CPU
model.load_state_dict(torch.load('model_params.pth', map_location=torch.device('cpu')))
总之,保存模型是训练深度学习模型的关键部分,它允许我们在训练中断时恢复,或在训练完成后部署模型。
如果训练过程中断并且你已经定期保存了检查点,那么你可以从最近的检查点恢复。以下是一个基本流程,描述如何在训练中断后从上次停止的地方继续:
加载检查点:
在开始训练之前,首先加载保存的检查点。
checkpoint = torch.load('best_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_accuracy = checkpoint.get('best_accuracy', -1) # 默认为-1,假设你保存了这个值
恢复训练:
使用从检查点中加载的 start_epoch
作为起始点,并从那里开始你的训练循环。
for epoch in range(start_epoch, total_epochs):
# 训练代码...
train_one_epoch()
# 验证代码...
current_accuracy = validate()
# 保存新的检查点,如果模型在验证集上有更好的性能
if current_accuracy > best_accuracy:
best_accuracy = current_accuracy
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_accuracy': best_accuracy
# ... 你可以添加其他信息,如loss等
}
torch.save(checkpoint, 'best_checkpoint.pth')
注意点:
ReduceLROnPlateau
或 StepLR
,那么你也应该保存和加载它的状态。这样可以确保学习率调整策略在中断后正确地继续。通过这种方式,你可以在训练中断后恢复并从上次停止的地方继续,而不会丢失任何进度。