• pytorch 使用DataParallel 单机多卡和单卡保存和加载模型的正确方法


    1.单卡训练,单卡加载

    这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(),
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    2.单卡训练,多卡加载

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(),
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    加载过程也没有任何改变,但是要注意,先加载模型参数,再对模型做并行化处理

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    3.多卡训练,单卡加载

    注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()

    保存:

    states = {
            'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()
            'state_dict_decoder': decoder.module.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    3.多卡训练,单卡加载,方法二

    使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(), 
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载:
    要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    4.多卡保存,多卡加载

    这就和多卡保存,单卡加载第二中方式一样了 使用model.state_dict()保存,加载的时候,要先把模型做并行化(在多卡上并行)

    保存:

    states = {
            'state_dict_encoder': encoder.state_dict(), 
            'state_dict_decoder': decoder.state_dict(),
        }
    torch.save(states, fname)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    加载: 要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

    #先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
    encoder = Encoder()
    decoder = Decoder()
    #并行处理模型
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)
    #然后加载参数
    checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
    encoder_state_dict=checkpoint['state_dict_encoder']
    decoder_state_dict=checkpoint['state_dict_decoder']
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
  • 相关阅读:
    卡尔曼家族从零解剖-(06) 一维卡尔曼滤波编程(c++)实践、透彻理解公式结果
    材质技术在AI去衣中的作用
    7.Scala类
    【教3妹学算法】144. 二叉树的前序遍历
    MyBatis面试题(总结最全面的面试题)
    2022最新最详细必成功的在Vscode中设置背景图、同时解决不受支持的问题
    【Proteus仿真】Arduino UNO +74C922键盘解码驱动4X4矩阵键盘
    JAVA中小型医院信息管理系统源码 医院系统源码
    第二证券:科创板创业板涨跌幅限制?
    【2022-8-27完美世界】完美世界图像算法岗笔试
  • 原文地址:https://blog.csdn.net/ZauberC/article/details/133179904