• pytorch 使用DataParallel 单机多卡和单卡保存和加载模型的正确方法


    1.单卡训练,单卡加载

    这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(),
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.单卡训练,多卡加载

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(),
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    加载过程也没有任何改变,但是要注意,先加载模型参数,再对模型做并行化处理

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    3.多卡训练,单卡加载

    注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()

    保存:

    states = {
            'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()
            'state_dict_decoder': decoder.module.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    3.多卡训练,单卡加载,方法二

    使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(), 
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    4.多卡保存,多卡加载

    这就和多卡保存,单卡加载第二中方式一样了 使用model.state_dict()保存,加载的时候,要先把模型做并行化(在多卡上并行)

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(), 
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载: 要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
  • 相关阅读:
    栈与队列:设计循环队列
    [极客大挑战 2019]BabySQL 1
    Dubbo学习(三)——dubbo实现负载均衡、智能容错功能
    spicy(一)基本定义
    LLM 面试总结
    Typora收费了, 还有哪些好用的markdown工具
    Day 40 Web容器-Tomcat
    正则表达式
    怒刷LeetCode的第21天(Java版)
    libevent库
  • 原文地址:https://blog.csdn.net/ZauberC/article/details/133179904