Basic usage:
torch.einsum(equation, *operands)
In the equation string, an index that repeats across operands is summed over, and the indices after '->' specify the output dimensions and their order.
import torch
A = torch.randn(2, 3)
B = torch.einsum('ij->ji', A)
# equivalent to B = A.transpose(0, 1)
C = torch.einsum('ik,kj->ij', A, B)
# equivalent to C = torch.matmul(A, B)
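A quick way to confirm the equivalence is to compare the einsum result against torch.matmul directly (a minimal check using the same shapes as above):

```python
import torch

A = torch.randn(2, 3)
B = A.transpose(0, 1)  # shape (3, 2)

C_einsum = torch.einsum('ik,kj->ij', A, B)  # shape (2, 2)
C_matmul = torch.matmul(A, B)

# the two results agree up to floating-point tolerance
print(torch.allclose(C_einsum, C_matmul))  # True
```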
a = torch.randn(3)
b = torch.randn(3)
c = torch.einsum('i,i->', a, b)
# equivalent to c = torch.dot(a, b)
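Note that 'i,i->' returns a 0-dimensional tensor, just as torch.dot does; .item() extracts the Python scalar. A small sketch with fixed values:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

c = torch.einsum('i,i->', a, b)  # 1*4 + 2*5 + 3*6 = 32
print(c.dim())   # 0  (scalar tensor)
print(c.item())  # 32.0
```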
A = torch.randn(5, 2, 3)
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
# equivalent to C = torch.bmm(A, B)
a = torch.randn(3)
b = torch.randn(4)
c = torch.einsum('i,j->ij', a, b)
# the result is a 3x4 matrix, equivalent to c = a.unsqueeze(1) * b.unsqueeze(0)
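The same outer product is also available as torch.outer; with small integer vectors the 3x4 shape is easy to see (a minimal check):

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([10, 20, 30, 40])

c = torch.einsum('i,j->ij', a, b)
print(c.shape)  # torch.Size([3, 4])
print(torch.equal(c, torch.outer(a, b)))  # True
```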
A = torch.randn(3, 3)
trace = torch.einsum('ii->', A)
# equivalent to trace = torch.trace(A)
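With a concrete matrix the trace is easy to verify by hand, since 'ii->' picks out the diagonal entries and sums them:

```python
import torch

A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

# diagonal sum: 1 + 5 + 9 = 15
trace = torch.einsum('ii->', A)
print(trace.item())           # 15.0
print(torch.trace(A).item())  # 15.0
```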
Full name: batch matrix-matrix product. It performs batched matrix multiplication on 3-D tensors, where the first dimension is the batch size and the second and third dimensions are the rows and columns of each matrix.
torch.bmm(input, mat2, *, out=None) -> Tensor
For example:
import torch
# define two 3-D tensors of shapes (b, n, m) and (b, m, p)
batch_size = 10
n, m, p = 3, 4, 5
A = torch.randn(batch_size, n, m)
B = torch.randn(batch_size, m, p)
# perform the batched matrix multiplication
C = torch.bmm(A, B)
print(C.shape)  # output: torch.Size([10, 3, 5])
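One difference worth noting: torch.bmm requires both inputs to be 3-D with matching batch sizes and does not broadcast, while torch.matmul broadcasts batch dimensions. A sketch of multiplying a batch by a single shared matrix:

```python
import torch

A = torch.randn(10, 3, 4)
W = torch.randn(4, 5)  # a single, shared weight matrix

# matmul broadcasts W across the batch dimension
C = torch.matmul(A, W)
print(C.shape)  # torch.Size([10, 3, 5])

# bmm insists on two 3-D tensors, so the 2-D W must be expanded first
C2 = torch.bmm(A, W.unsqueeze(0).expand(10, -1, -1))
print(torch.allclose(C, C2))  # True
```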
More concretely:
A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
# A.shape = (2, 2, 2)
# B.shape = (2, 2, 2)
C = torch.bmm(A, B)
print(C)
# output:
# tensor([[[ 31,  34],
#          [ 71,  78]],
#
#         [[155, 166],
#          [211, 226]]])
The underlying math is:
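Each batch is an ordinary matrix product, C[b][i][j] = Σ_k A[b][i][k] * B[b][k][j]. Spelled out for the first batch of the example above (the second batch follows the same pattern):

```python
import torch

A0 = torch.tensor([[1, 2], [3, 4]])     # A[0]
B0 = torch.tensor([[9, 10], [11, 12]])  # B[0]

# C[0][0][0] = 1*9  + 2*11 = 31
# C[0][0][1] = 1*10 + 2*12 = 34
# C[0][1][0] = 3*9  + 4*11 = 71
# C[0][1][1] = 3*10 + 4*12 = 78
print((A0 @ B0).tolist())  # [[31, 34], [71, 78]]
```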