• Dive into Deep Learning (PyTorch Edition), 5.3 Deferred Initialization


    import torch
    from torch import nn
    from d2l import torch as d2l
    

    The multilayer perceptron instantiated below has an unknown input dimension, so the framework has not yet initialized any of its parameters; they display as "UninitializedParameter".

    net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
    
    net[0].weight
    
    c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\lazy.py:178: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
      warnings.warn('Lazy modules are a new feature under heavy development '

    <UninitializedParameter>
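A small sketch of how to check this state programmatically (assuming a PyTorch version with lazy modules, i.e. 1.8 or later): before the first forward pass, a lazy layer's parameters are placeholders of type `UninitializedParameter`, and the forward pass replaces them with real tensors.

```python
import torch
from torch import nn

net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))

# Before any forward pass: parameters are placeholders.
print(net[0].has_uninitialized_params())                               # True
print(isinstance(net[0].weight, nn.parameter.UninitializedParameter))  # True

# One forward pass materializes them into ordinary tensors.
net(torch.rand(2, 20))
print(isinstance(net[0].weight, nn.parameter.UninitializedParameter))  # False
```

Note that after materialization the module's class is converted to its non-lazy counterpart (`nn.Linear` here), so lazy-only methods such as `has_uninitialized_params` are no longer available on it.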

    Once the input dimension is specified, the framework can initialize the parameters layer by layer.

    X = torch.rand(2, 20)
    net(X)
    
    net[0].weight.shape
    
    torch.Size([256, 20])
    
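This is the main payoff of deferred initialization: shape-dependent initializers such as Xavier cannot run until the parameter shapes exist. A sketch of applying one after the first forward pass (`init_weights` is a hypothetical helper, not part of the source):

```python
import torch
from torch import nn

net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
X = torch.rand(2, 20)
net(X)  # first forward pass materializes all parameter shapes

# Only now can a shape-dependent initializer be applied.
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

net.apply(init_weights)
print(net[0].weight.shape)  # torch.Size([256, 20])
```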

    Exercises

    (1) What happens if you specify the input dimension of the first layer but not the dimensions of the subsequent layers? Is initialization performed immediately?

    net = nn.Sequential(
        nn.Linear(20, 256), nn.ReLU(),
        nn.LazyLinear(128), nn.ReLU(),
        nn.LazyLinear(10)
    )
    net[0].weight, net[2].weight, net[4].weight
    
    (Parameter containing:
     tensor([[ 0.1332,  0.1372, -0.0939,  ..., -0.0579, -0.0911, -0.1820],
             [-0.1570, -0.0993, -0.0685,  ..., -0.0469, -0.0208,  0.0665],
             [ 0.0861,  0.1135,  0.1631,  ..., -0.1407,  0.1088, -0.2052],
             ...,
             [-0.1454, -0.0283, -0.1074,  ..., -0.2164, -0.2169,  0.1913],
             [-0.1617,  0.1206, -0.2119,  ..., -0.1862, -0.0951,  0.1535],
             [-0.0229, -0.2133, -0.1027,  ...,  0.1973,  0.1314,  0.1283]],
            requires_grad=True),
     <UninitializedParameter>,
     <UninitializedParameter>)
    
    net(X)  # deferred initialization happens on the first forward pass
    net[0].weight.shape, net[2].weight.shape, net[4].weight.shape
    
    (torch.Size([256, 20]), torch.Size([128, 256]), torch.Size([10, 128]))
    
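So the answer is: the first layer is initialized immediately (its shape is fully specified), while the lazy layers wait for the first forward pass. For completeness, lazy variants exist for other layer types too; a sketch with `LazyConv2d`, which infers `in_channels` from the input:

```python
import torch
from torch import nn

# LazyConv2d(out_channels, kernel_size) leaves in_channels to be inferred.
net = nn.Sequential(nn.LazyConv2d(6, kernel_size=5), nn.ReLU(),
                    nn.Flatten(), nn.LazyLinear(10))

net(torch.rand(1, 3, 28, 28))  # a 3-channel input fixes all shapes
print(net[0].weight.shape)     # torch.Size([6, 3, 5, 5])
print(net[3].weight.shape)     # torch.Size([10, 3456]), since 6 * 24 * 24 = 3456
```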

    (2) What happens if you specify mismatched dimensions?

    X = torch.rand(2, 10)
    net(X)  # raises a RuntimeError
    
    ---------------------------------------------------------------------------
    
    RuntimeError                              Traceback (most recent call last)
    
    Cell In[14], line 2
          1 X = torch.rand(2, 10)
    ----> 2 net(X)
    
    
    File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
       1126 # If we don't have any hooks, we want to skip the rest of the logic in
       1127 # this function, and just call forward.
       1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1129         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1130     return forward_call(*input, **kwargs)
       1131 # Do not call functions when jit is used
       1132 full_backward_hooks, non_full_backward_hooks = [], []
    
    
    File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\container.py:139, in Sequential.forward(self, input)
        137 def forward(self, input):
        138     for module in self:
    --> 139         input = module(input)
        140     return input
    
    
    File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, **kwargs)
       1126 # If we don't have any hooks, we want to skip the rest of the logic in
       1127 # this function, and just call forward.
       1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1129         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1130     return forward_call(*input, **kwargs)
       1131 # Do not call functions when jit is used
       1132 full_backward_hooks, non_full_backward_hooks = [], []
    
    
    File c:\Software\Miniconda3\envs\d2l\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
        113 def forward(self, input: Tensor) -> Tensor:
    --> 114     return F.linear(input, self.weight, self.bias)
    
    
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x10 and 20x256)
    

    (3) What do you need to do if the inputs have different dimensions?

    Adapt the input to the expected dimension: pad it up if it has too few features, or reduce it (for example with a projection layer) if it has too many.
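Both options can be sketched as follows, reusing the fixed-dimension network from exercise (2); the padding amounts and the projection layer are illustrative choices, not part of the source:

```python
import torch
import torch.nn.functional as F
from torch import nn

net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

# Too few features: zero-pad the feature dimension from 10 up to 20.
X_small = torch.rand(2, 10)
X_pad = F.pad(X_small, (0, 10))  # pad 0 on the left, 10 on the right
print(net(X_pad).shape)          # torch.Size([2, 10])

# Too many features: learn a projection down to 20.
X_big = torch.rand(2, 50)
proj = nn.Linear(50, 20)
print(net(proj(X_big)).shape)    # torch.Size([2, 10])
```

Zero-padding is parameter-free but treats the extra slots as missing data, whereas a projection layer adds learnable weights that must be trained along with the rest of the network.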

  • Original post: https://blog.csdn.net/qq_43941037/article/details/132914685