这个是层归一化。我们输入一个参数,这个参数就必须与最后一个维度对应。但是我们也可以输入多个维度,但是必须从后向前对应。
- import torch
- import torch.nn as nn
-
- a = torch.rand((100,5))
- c = nn.LayerNorm([5])
- print(c(a).shape)
-
- a = torch.rand((100,5,8,9))
- c = nn.LayerNorm([9])
- print(c(a).shape)
-
- a = torch.rand((100,5,8,9))
- c = nn.LayerNorm([8,9])
- print(c(a).shape)
-
- a = torch.rand((100,5,8,9))
- c = nn.LayerNorm([5,8,9])
- print(c(a).shape)
