• Parameter Sharing Between Models in PyTorch


    When training neural network models, we sometimes need to share network parameters between different models. Below I use an example to show how to share model parameters.

    The modules whose parameters are to be shared must have exactly the same structure for parameter sharing to work.

    1. Sharing a specified module

    Suppose we have the following two models:

    import torch
    import torch.nn as nn

    class ANN1(nn.Module):
        def __init__(self,features):
            super(ANN1, self).__init__()
            self.features = features
            self.nn_same = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )
            self.nn_diff = torch.nn.Sequential(
                nn.Linear(128, 1)
            )
    
        def forward(self, x):
            # x(batch_size, features)
            x = self.nn_same(x)
            x = self.nn_diff(x)
            return x
    class ANN2(nn.Module):
        def __init__(self,features):
            super(ANN2, self).__init__()
            self.features = features
            self.nn_same = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )
            self.nn_diff = torch.nn.Sequential(
                nn.Linear(128, 1)
            )
    
        def forward(self, x):
            # x(batch_size, features)
            x = self.nn_same(x)
            x = self.nn_diff(x)
            return x
        
    model1 = ANN1(10)
    model2 = ANN2(10)
    print(model1)
    print(model2)
    
    ANN1(
      (nn_same): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_diff): Sequential(
        (0): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    ANN2(
      (nn_same): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_diff): Sequential(
        (0): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    

    Here nn_same denotes the module whose parameters will be shared. The module names may differ between the two models, but the module structures must be exactly the same.
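
    As a quick sanity check, here is a minimal sketch of a structural-compatibility test (the helper same_structure is my own, not a PyTorch API): two modules can exchange state dicts only if their parameter names and shapes line up.

    def same_structure(m1, m2):
        # Hypothetical helper: compare the parameter names and shapes of two modules.
        s1, s2 = m1.state_dict(), m2.state_dict()
        return s1.keys() == s2.keys() and all(s1[k].shape == s2[k].shape for k in s1)

    print(same_structure(model1.nn_same, model2.nn_same))  # expected: True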

    Because a model's parameters are randomly initialized when it is created, the two models' parameters are certain to differ. Suppose we want to transfer the parameters of model1's nn_same module to model2's nn_same module. First, let's look at the parameters of model1.nn_same:

    for param_tensor in model1.nn_same.state_dict():  # print the parameters before the transfer
        print(param_tensor, "\t", model1.nn_same.state_dict()[param_tensor])
    
    0.weight 	 tensor([[ 0.1321, -0.0178,  0.1631,  ..., -0.2531, -0.1584,  0.0588],
            [-0.2466, -0.0381,  0.2394,  ..., -0.2924, -0.1267, -0.1791],
            [-0.1713, -0.0716,  0.0598,  ...,  0.1655, -0.1947,  0.0927],
            ...,
            [-0.1795, -0.3082, -0.2846,  ...,  0.2588, -0.0998, -0.1285],
            [-0.2739, -0.1587,  0.1803,  ..., -0.1905, -0.2832, -0.0724],
            [ 0.1375, -0.1854, -0.1928,  ...,  0.1470,  0.2928,  0.1385]])
    0.bias 	 tensor([-0.2251, -0.3036,  0.2147, -0.0798, -0.1079, -0.0396, -0.1078,  0.1006,
            -0.1884, -0.0616,  0.0698,  0.0044,  0.1615, -0.2090,  0.0584, -0.0743,
             ···,
             0.3010, -0.1674,  0.0982,  0.2267, -0.0865, -0.1350, -0.2501,  0.1475,
             0.0187,  0.0819,  0.1840, -0.0988,  0.0133, -0.2082,  0.0376,  0.2993])
    
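    Before transferring, you can also confirm programmatically that the two nn_same modules currently hold different weights. A minimal sketch, reusing the model1 and model2 instances above:

    s1 = model1.nn_same.state_dict()
    s2 = model2.nn_same.state_dict()
    # Independently initialized modules are (virtually) certain to differ.
    print(all(torch.equal(s1[k], s2[k]) for k in s1))  # expected: False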

    Now let's transfer the parameters:

    print("****************迁移前*****************")
    for param_tensor in model2.nn_same.state_dict():#输出迁移前的参数
        print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
        
    model_nn_same = model1.nn_same.state_dict() ##获取model的nn_same部分的参数
    model2.nn_same.load_state_dict(model_nn_same,strict=True) #更新model2 nn_same部分的参数,#更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
    
    print("****************迁移后*****************")
    for param_tensor in model2.nn_same.state_dict():#输出迁移后的参数
        print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
    #此时nn_same参数更新,nn_diff2参数不变
    
    **************** Before transfer *****************
    0.weight 	 tensor([[-0.1030, -0.0111,  0.0989,  ..., -0.3142, -0.0167,  0.0485],
            [ 0.1671,  0.2833,  0.1353,  ...,  0.1657, -0.2497, -0.1680],
            [ 0.0470,  0.1208,  0.1707,  ..., -0.0018,  0.2497,  0.0419],
            ...,
            [-0.2406, -0.2757,  0.2527,  ..., -0.0888, -0.2772,  0.1019],
            [-0.3035, -0.0227, -0.0194,  ...,  0.1280, -0.1167,  0.1060],
            [ 0.0565,  0.1870, -0.2729,  ..., -0.1215,  0.1343, -0.1057]])
    0.bias 	 tensor([ 0.0855,  0.3137,  0.2336, -0.2197,  0.0132, -0.1812, -0.1490, -0.1348,
             0.1027,  0.0284,  0.1064,  0.2046,  0.1106, -0.2034, -0.1283, -0.1561,
             ···,
             0.0328, -0.1035, -0.2942, -0.2368, -0.2290,  0.1846, -0.0270,  0.1286,
            -0.2331,  0.1111,  0.2172, -0.2865,  0.2086, -0.1388, -0.2077, -0.2976])
    **************** After transfer *****************
    0.weight 	 tensor([[ 0.1321, -0.0178,  0.1631,  ..., -0.2531, -0.1584,  0.0588],
            [-0.2466, -0.0381,  0.2394,  ..., -0.2924, -0.1267, -0.1791],
            [-0.1713, -0.0716,  0.0598,  ...,  0.1655, -0.1947,  0.0927],
            ...,
            [-0.1795, -0.3082, -0.2846,  ...,  0.2588, -0.0998, -0.1285],
            [-0.2739, -0.1587,  0.1803,  ..., -0.1905, -0.2832, -0.0724],
            [ 0.1375, -0.1854, -0.1928,  ...,  0.1470,  0.2928,  0.1385]])
    0.bias 	 tensor([-0.2251, -0.3036,  0.2147, -0.0798, -0.1079, -0.0396, -0.1078,  0.1006,
            -0.1884, -0.0616,  0.0698,  0.0044,  0.1615, -0.2090,  0.0584, -0.0743,
             ···,
             0.3010, -0.1674,  0.0982,  0.2267, -0.0865, -0.1350, -0.2501,  0.1475,
             0.0187,  0.0819,  0.1840, -0.0988,  0.0133, -0.2082,  0.0376,  0.2993])
    

    As you can see, the parameters of model2's nn_same module are now identical to those of model1's nn_same module.
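
    Note that load_state_dict copies the parameter values rather than tying the two modules to the same tensors, so after the transfer the models still train independently. A minimal sketch to illustrate (it perturbs model2's weights purely for demonstration):

    # the values match right after the transfer...
    print(torch.equal(model1.nn_same[0].weight, model2.nn_same[0].weight))  # True
    # ...but the storage is separate: changing one model leaves the other intact
    with torch.no_grad():
        model2.nn_same[0].weight.add_(1.0)
    print(torch.equal(model1.nn_same[0].weight, model2.nn_same[0].weight))  # False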

    2. Sharing all modules with the same name

    Suppose we have the following two models:

    class ANN1(nn.Module):
        def __init__(self, features):
            super(ANN1, self).__init__()
            self.features = features
            self.nn_same1 = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )

            self.nn_same2 = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )

            self.nn_diff1 = torch.nn.Sequential(
                nn.Linear(128, 1)
            )

        def forward(self, x):
            # x: (batch_size, features)
            # both shared branches consume the raw input, so combine them by summing
            x = self.nn_same1(x) + self.nn_same2(x)
            x = self.nn_diff1(x)
            return x

    class ANN2(nn.Module):
        def __init__(self, features):
            super(ANN2, self).__init__()
            self.features = features
            self.nn_same1 = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )

            self.nn_same2 = torch.nn.Sequential(
                nn.Linear(features, 128),
                torch.nn.ReLU(),
            )

            self.nn_diff2 = torch.nn.Sequential(
                nn.Linear(128, 1)
            )

        def forward(self, x):
            # x: (batch_size, features)
            x = self.nn_same1(x) + self.nn_same2(x)
            x = self.nn_diff2(x)
            return x

    model1 = ANN1(10)
    model2 = ANN2(10)
    print(model1)
    print(model2)
    
    • 51
    ANN1(
      (nn_same1): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_same2): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_diff1): Sequential(
        (0): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    ANN2(
      (nn_same1): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_same2): Sequential(
        (0): Linear(in_features=10, out_features=128, bias=True)
        (1): ReLU()
      )
      (nn_diff2): Sequential(
        (0): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    

    Suppose we want to transfer the parameters of model1's nn_same1 and nn_same2 modules to model2's nn_same1 and nn_same2:

    print("****************迁移前*****************")
    for param_tensor in model2.state_dict():#输出迁移前的参数
        print(param_tensor, "\t", model2.state_dict()[param_tensor])
        
    model_all = model1.state_dict() ##获取model的所有的参数
    model2.load_state_dict(model_all,strict=False) #更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
    
    print("****************迁移后*****************")
    for param_tensor in model2.state_dict():#输出迁移后的参数
        print(param_tensor, "\t", model2.state_dict()[param_tensor])
    #此时nn_same参数更新,nn_diff2参数不变
    

    Note the strict=False in model2.load_state_dict(model_all, strict=False): it means the module names of the two models do not have to match exactly, and only the modules whose names match are updated. If the module names do not fully match but strict=True is used, an error is raised:

    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    <ipython-input-56-069ae53e28f3> in <module>
          4 
          5 model_all = model1.state_dict()  # get all of model1's parameters
    ----> 6 model2.load_state_dict(model_all,strict=True)  # update model2's parameters (strict=True: all keys must match)
          7 
          8 print("****************迁移后*****************")
    
    D:\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in load_state_dict(self, state_dict, strict)
       1481         if len(error_msgs) > 0:
       1482             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    -> 1483                                self.__class__.__name__, "\n\t".join(error_msgs)))
       1484         return _IncompatibleKeys(missing_keys, unexpected_keys)
       1485 
    
    RuntimeError: Error(s) in loading state_dict for ANN2:
    	Missing key(s) in state_dict: "nn_diff2.0.weight", "nn_diff2.0.bias". 
    	Unexpected key(s) in state_dict: "nn_diff1.0.weight", "nn_diff1.0.bias".
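
    With strict=False, load_state_dict does not raise; instead it returns a result object whose missing_keys and unexpected_keys fields report what was skipped (you can see _IncompatibleKeys being constructed in the traceback above). It is worth inspecting them so a silently skipped module does not go unnoticed. A minimal sketch:

    result = model2.load_state_dict(model1.state_dict(), strict=False)
    print(result.missing_keys)     # keys model2 expects but the state_dict lacks, e.g. nn_diff2.*
    print(result.unexpected_keys)  # keys in the state_dict that model2 has no module for, e.g. nn_diff1.*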
    
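    If you would rather not rely on strict=False skipping keys implicitly, you can also filter the state dict down to just the modules you intend to share before loading. A minimal sketch (the key prefixes are specific to this example):

    shared = {k: v for k, v in model1.state_dict().items()
              if k.startswith(("nn_same1.", "nn_same2."))}
    model2.load_state_dict(shared, strict=False)  # strict=False only because nn_diff2.* is absent from `shared`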
  • Original post: https://blog.csdn.net/cyj972628089/article/details/127325735