
- from torch import nn
- m = nn.Bilinear(96, 96, 96)
- input1 = torch.randn(8,7, 96)
- input2 = torch.randn(8,7, 96)
- output = m(input1, input2)
- print(output.size())
torch.Size([8, 7, 96])
参考资料
python - Understanding Bilinear Layers - Stack Overflow
pytorch中的nn.Bilinear的计算原理详解_nihate的博客-CSDN博客_pytorch中bilinear //讲得很细,非常推荐,如果像搞清楚计算原理的话