• pytorch 实战【以图像处理为例】



    训练过程中保存模型

    在PyTorch中,模型训练过程中保存模型通常涉及以下几个步骤:

    1. 保存整个模型:
      使用 torch.save 函数,你可以保存整个模型,包括模型的结构和参数。

      torch.save(model, 'model.pth')
      
      • 1

      加载模型时,使用 torch.load 函数。

      model = torch.load('model.pth')
      
      • 1
    2. 保存模型的参数:
      这种方法通常更受欢迎,因为它只保存模型的参数,不保存模型的结构。这样,模型文件会比较小,并且在加载模型时可以更加灵活。

      torch.save(model.state_dict(), 'model_params.pth')
      
      • 1

      加载模型时,首先创建模型的实例,然后加载参数。

      model = ModelClass()  # replace ModelClass with your model's class name
      model.load_state_dict(torch.load('model_params.pth'))
      
      • 1
      • 2
    3. 保存训练的检查点:
      在训练过程中,除了保存模型或模型的参数,通常还会保存其他关键信息,例如优化器的状态、当前的epoch、最佳准确率等。这样,如果训练被中断,可以从检查点继续训练,而不是从头开始。

      checkpoint = {
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'loss': loss,
          # ... any other relevant information
      }
      torch.save(checkpoint, 'checkpoint.pth')
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8

      加载检查点时:

      checkpoint = torch.load('checkpoint.pth')
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      epoch = checkpoint['epoch']
      loss = checkpoint['loss']
      
      • 1
      • 2
      • 3
      • 4
      • 5
    4. 在训练时定期保存模型:
      通常,我们会在每个epoch结束时或在验证准确率提高时保存模型。这样,如果训练过程中出现任何问题,我们可以从最近的检查点恢复。

    • 保存检查点:

    在训练循环中,你可能会在每个 epoch 结束时或在模型在验证集上达到新的最佳性能时保存检查点:

    # 假设以下变量已经定义:
    # model: 你的模型
    # optimizer: 你使用的优化器
    # epoch: 当前的epoch
    # loss: 最近的loss值
    # best_accuracy: 迄今为止在验证集上的最佳准确率
    
    # 在每个 epoch 结束时或在验证准确率提高时:
    if current_accuracy > best_accuracy:  # current_accuracy是这个epoch在验证集上的准确率
        best_accuracy = current_accuracy
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'best_accuracy': best_accuracy
        }
        torch.save(checkpoint, 'best_checkpoint.pth')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 加载检查点:

    当你希望从检查点继续训练或评估模型时,可以使用以下代码来加载检查点:

    # 假设以下变量已经定义:
    # model: 你的模型 (需要先实例化)
    # optimizer: 你使用的优化器 (需要先实例化)
    
    checkpoint = torch.load('best_checkpoint.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    best_accuracy = checkpoint['best_accuracy']
    
    # 如果继续训练,可以从上一个 epoch 开始
    model.train()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    这样,即使训练过程中断,你也可以从上次停止的地方继续,而不是重新开始。

    1. 保存在不同设备上的模型:
      如果你在GPU上训练模型,但希望在CPU上加载模型,可以使用以下方式:
      torch.save(model.state_dict(), 'model_params.pth')
      # Loading on CPU
      model.load_state_dict(torch.load('model_params.pth', map_location=torch.device('cpu')))
      
      • 1
      • 2
      • 3

    总之,保存模型是训练深度学习模型的关键部分,它允许我们在训练中断时恢复,或在训练完成后部署模型。

    具体在训练中断如何继续

    如果训练过程中断并且你已经定期保存了检查点,那么你可以从最近的检查点恢复。以下是一个基本流程,描述如何在训练中断后从上次停止的地方继续:

    1. 加载检查点:

      在开始训练之前,首先加载保存的检查点。

      checkpoint = torch.load('best_checkpoint.pth')
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      start_epoch = checkpoint['epoch'] + 1
      best_accuracy = checkpoint.get('best_accuracy', -1)  # 默认为-1,假设你保存了这个值
      
      • 1
      • 2
      • 3
      • 4
      • 5
    2. 恢复训练:

      使用从检查点中加载的 start_epoch 作为起始点,并从那里开始你的训练循环。

      for epoch in range(start_epoch, total_epochs):
          # 训练代码...
          train_one_epoch()
      
          # 验证代码...
          current_accuracy = validate()
      
          # 保存新的检查点,如果模型在验证集上有更好的性能
          if current_accuracy > best_accuracy:
              best_accuracy = current_accuracy
              checkpoint = {
                  'epoch': epoch,
                  'model_state_dict': model.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'best_accuracy': best_accuracy
                  # ... 你可以添加其他信息,如loss等
              }
              torch.save(checkpoint, 'best_checkpoint.pth')
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
      • 10
      • 11
      • 12
      • 13
      • 14
      • 15
      • 16
      • 17
      • 18
    3. 注意点:

      • 学习率调整:如果你使用了学习率调度器,例如 ReduceLROnPlateauStepLR,那么你也应该保存和加载它的状态。这样可以确保学习率调整策略在中断后正确地继续。
      • 随机种子:为了确保训练的可复现性,如果你设置了随机种子,那么在恢复训练之前,你可能需要重新设置相同的随机种子。

    通过这种方式,你可以在训练中断后恢复并从上次停止的地方继续,而不会丢失任何进度。

  • 相关阅读:
    机械设计基础重点
    【Kingbase FlySync】命令模式:安装部署同步软件,实现KES到KES实现同步
    云原生下一步的发展方向
    计算机毕业设计SSM电影票网上订票系统【附源码数据库】
    App常用接口
    2022年《微信小程序从基础到uni-app项目实战》
    ubuntu 软件管理
    java计算机毕业设计华夏球迷俱乐部网站设计与实现源码+mysql数据库+系统+lw文档+部署
    Redis设计与实现-数据结构(建设进度17%)
    Three开关门
  • 原文地址:https://blog.csdn.net/weixin_42785537/article/details/133253931