官网 pytorch linear:

Input: (*, H_{in})(∗,Hin) where *∗ means any number of dimensions including none and H_{in} = \text{in\_features}Hin=in_features. 任意维度 number 理解有歧义 (a) number. k可以理解三维,四维。。。 (b) 可以理解 为 某一维度的数 。
Output: (*, H_{out})(∗,Hout) where all but the last dimension are the same shape as the input and H_{out} = \text{out\_features}Hout=out_features.
分别 用三维 和二维输入数组,查看他们参数数目是否一样。
- import torch
-
- x = torch.randn(128, 20) # 输入的维度是(128,20)
- m = torch.nn.Linear(20, 30) # 20,30是指维度
- output = m(x)
- print('m.weight.shape:\n ', m.weight.shape)
- print('m.bias.shape:\n', m.bias.shape)
- print('output.shape:\n', output.shape)
-
- # ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
- ans = torch.mm(x, m.weight.t()) + m.bias
- print('ans.shape:\n', ans.shape)
-
- print(torch.equal(ans, output))
output:
m.weight.shape: torch.Size([30, 20]) m.bias.shape: torch.Size([30]) output.shape: torch.Size([128, 30]) ans.shape: torch.Size([128, 30]) True
- x = torch.randn(128, 30,20) # 输入的维度是(128,30,20)
- m = torch.nn.Linear(20, 30) # 20,30是指维度
- output = m(x)
- print('m.weight.shape:\n ', m.weight.shape)
- print('m.bias.shape:\n', m.bias.shape)
- print('output.shape:\n', output.shape)
ouput: m.weight.shape: torch.Size([30, 20]) m.bias.shape: torch.Size([30]) output.shape: torch.Size([128, 30, 30])
(128,30,20),和 (128,20) 分别是如 nn.linear(30,20) 层。
weight.shape 均为: (30,20)
linear() 参数数目只和 input_dim ,output_dim 有关。
weight 在源码的定义, 没找到如何计算多维input的代码。
