When the operator is the most natural one of all, *, the multiplication is element-wise and the operands are broadcast.
For more details, see Tensor unsqueeze for broadcasting.
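A minimal sketch of element-wise * and its broadcasting, using small made-up tensors (a, b, row and col are illustrative names, not taken from the examples below):

```python
import torch

# Element-wise product of two tensors with the same shape (2, 3)
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
b = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])
prod = a * b
print(prod)       # values [[10, 40, 90], [160, 250, 360]]

# Broadcasting: (2, 3) * (3,). The (3,) tensor acts like shape (1, 3)
# and is repeated along dim 0.
row = torch.tensor([1, 0, -1])
prod_row = a * row
print(prod_row)   # values [[1, 0, -3], [4, 0, -6]]

# A (2, 1) column broadcasts along dim 1 instead
col = torch.tensor([[2], [3]])
prod_col = a * col
print(prod_col)   # values [[2, 4, 6], [12, 15, 18]]
```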
torch.mm simply performs matrix multiplication.
torch.mm(input, mat2, *, out=None) → Tensor
Performs a matrix multiplication of the matrices input and mat2.
If input is a $(n \times m)$ tensor, mat2 is a $(m \times p)$ tensor, out will be a $(n \times p)$ tensor.
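The shape rule above can be checked directly. This sketch uses arbitrary made-up sizes n = 2, m = 3, p = 4, and also recomputes every entry from the definition C[i, j] = sum over k of A[i, k] * B[k, j], using broadcasting followed by a sum over the shared dimension:

```python
import torch

# Hypothetical sizes: A is (n x m), B is (m x p)
n, m, p = 2, 3, 4
A = torch.arange(n * m).reshape(n, m)
B = torch.arange(m * p).reshape(m, p)

C = torch.mm(A, B)
print(C.shape)  # torch.Size([2, 4]), i.e. (n x p)

# Same result from the definition:
# (n, m, 1) * (1, m, p) -> (n, m, p), then sum over k (dim 1) -> (n, p)
C_manual = (A.unsqueeze(2) * B.unsqueeze(0)).sum(dim=1)
print(torch.equal(C, C_manual))  # True
```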
import torch
Mat1 = torch.tensor([[1, 6, 7],
[2, 5, 8],
[3, 4, 9]])
Mat2 = torch.tensor([[9, 6, 3],
[8, 5, 2],
[7, 4, 1]])
Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)
Output:
tensor([[106, 64, 22],
[114, 69, 24],
[122, 74, 26]])
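To see where the numbers come from, this sketch recomputes entry (0, 0) of the result by hand (row 0 of Mat1 times column 0 of Mat2, using element-wise * and a sum) and checks that Python's @ operator gives the same result as torch.mm for 2D tensors:

```python
import torch

Mat1 = torch.tensor([[1, 6, 7],
                     [2, 5, 8],
                     [3, 4, 9]])
Mat2 = torch.tensor([[9, 6, 3],
                     [8, 5, 2],
                     [7, 4, 1]])
Mat3 = torch.mm(Mat1, Mat2)

# Entry (0, 0): row 0 of Mat1 paired with column 0 of Mat2
# 1*9 + 6*8 + 7*7 = 9 + 48 + 49 = 106
entry_00 = (Mat1[0] * Mat2[:, 0]).sum()
print(entry_00.item())  # 106

# For two 2D tensors, @ gives the same result as torch.mm
print(torch.equal(Mat1 @ Mat2, Mat3))  # True
```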
import torch
Mat1 = torch.tensor([[1, 3]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[6, 4, 2],
[5, 3, 1]])
print("Mat2's shape: ",Mat2.shape)
Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)
Output:
Mat1's shape: torch.Size([1, 2])
Mat2's shape: torch.Size([2, 3])
tensor([[21, 13, 5]])
If Mat1 is changed to a 1D tensor:
Mat1 = torch.tensor([1, 3])
Output:
Mat1's shape: torch.Size([2])
Mat2's shape: torch.Size([2, 3])
Mat3 = torch.mm(Mat1, Mat2, out=None)
RuntimeError: self must be a matrix
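torch.mm only accepts 2D tensors. Assuming the intent is to treat [1, 3] as a 1×2 row vector, one way to make the call valid is to add the missing dimension with unsqueeze(0); a minimal sketch:

```python
import torch

Mat1 = torch.tensor([1, 3])            # shape [2]: 1D, rejected by torch.mm
Mat2 = torch.tensor([[6, 4, 2],
                     [5, 3, 1]])       # shape [2, 3]

# unsqueeze(0) inserts a leading dimension: [2] -> [1, 2]
Mat1_2d = Mat1.unsqueeze(0)
print(Mat1_2d.shape)  # torch.Size([1, 2])

Mat3 = torch.mm(Mat1_2d, Mat2)
print(Mat3)  # same values as the [[1, 3]] example: [[21, 13, 5]]
```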
import torch
Mat1 = torch.tensor([[1, 3],
[2, 4]] )
Mat2 = torch.tensor([[6, 4, 2],
[5, 3, 1],
[7,8,9]])
Mat3 = torch.mm(Mat1, Mat2, out=None)
print(Mat3)
Output:
Mat3 = torch.mm(Mat1, Mat2, out=None)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 3x3)
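The error comes from the inner dimensions: a 2×2 matrix needs a second matrix with 2 rows, but Mat2 has 3. The helper can_mm below is hypothetical (not a PyTorch function); it only encodes the $(n \times m)(m \times p)$ rule so shapes can be checked before calling torch.mm. Slicing Mat2 down to its first two rows is purely for illustration:

```python
import torch

def can_mm(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if torch.mm(a, b) has compatible shapes."""
    # torch.mm needs two 2D tensors shaped (n, m) and (m, p)
    return a.dim() == 2 and b.dim() == 2 and a.shape[1] == b.shape[0]

Mat1 = torch.tensor([[1, 3],
                     [2, 4]])        # shape (2, 2)
Mat2 = torch.tensor([[6, 4, 2],
                     [5, 3, 1],
                     [7, 8, 9]])     # shape (3, 3)

print(can_mm(Mat1, Mat2))        # False: inner dimensions 2 and 3 differ

Mat2_top = Mat2[:2]              # first two rows only, shape (2, 3)
print(can_mm(Mat1, Mat2_top))    # True: (2, 2) x (2, 3)
result = torch.mm(Mat1, Mat2_top)
print(result)                    # values [[21, 13, 5], [32, 20, 8]]
```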
import torch
Vec1 = torch.tensor([6,4,2])
print("Vec1's shape: ",Vec1.shape)
Vec2 = torch.tensor([5,3,1])
print("Vec2's shape: ",Vec2.shape)
Vec3 = torch.matmul(Vec1, Vec2)
print("Vec3: ",Vec3)
print("Vec3's shape: ",Vec3.shape,"\n")
Output:
Vec1's shape: torch.Size([3])
Vec2's shape: torch.Size([3])
Vec3: tensor(44)
Vec3's shape: torch.Size([])
Note that the shape of Vec3 is torch.Size([]): the result is a 0-dimensional (scalar) tensor.
By contrast, print(torch.tensor([44]).shape) gives torch.Size([1]), not torch.Size([]).
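The difference is between a 0-dimensional tensor (built from a plain number) and a 1-dimensional tensor holding one element (built from a list). A small sketch comparing the two, and using .item() to get the Python number back:

```python
import torch

Vec1 = torch.tensor([6, 4, 2])
Vec2 = torch.tensor([5, 3, 1])
Vec3 = torch.matmul(Vec1, Vec2)   # 1D x 1D -> dot product, a 0-dim tensor

scalar = torch.tensor(44)         # built from a Python number: 0-dim
one_elem = torch.tensor([44])     # built from a list: 1-dim, one element

print(Vec3.dim(), scalar.dim(), one_elem.dim())  # 0 0 1
print(scalar.shape, one_elem.shape)              # torch.Size([]) torch.Size([1])

# .item() turns a one-element tensor into a Python number
print(Vec3.item())                               # 44
print(torch.equal(Vec3, scalar))                 # True
```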
import torch
Mat1 = torch.tensor([[1, 3]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[6, 4, 2],
[5, 3, 1]])
print("Mat2's shape: ",Mat2.shape)
Mat3 = torch.matmul(Mat1, Mat2)
print("Mat3: ",Mat3)
print("Mat3's shape: ",Mat3.shape,"\n")
Mat4 = torch.matmul(Mat2, Mat1)
print("Mat4: ",Mat4)
print("Mat4's shape: ",Mat4.shape,"\n")
Output:
Mat1's shape: torch.Size([1, 2])
Mat2's shape: torch.Size([2, 3])
Mat3: tensor([[21, 13, 5]])
Mat3's shape: torch.Size([1, 3])
Traceback (most recent call last):
File "D:/Test2022.py", line 29, in <module>
Mat4 = torch.matmul(Mat2, Mat1)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 1x2)
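Note that Mat3 above is identical to the result of Example 2 computed with torch.mm.
This shows that for a 2D tensor times a 2D tensor, torch.matmul also performs ordinary matrix multiplication.
Note the error for Mat4: a $(2 \times 3)$ matrix times a $(1 \times 2)$ matrix is undefined, because the inner dimensions (3 and 1) do not match. Unlike element-wise *, matrix multiplication depends on the order of the operands.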
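Mat2 · Mat1 is simply undefined for these shapes. If the goal is a product with the operands in the other order, one option (an illustration, not something the original example does) is the identity $(AB)^T = B^T A^T$: transposing both operands and swapping them gives a valid $(3 \times 2)(2 \times 1)$ product, which equals the transpose of Mat3:

```python
import torch

Mat1 = torch.tensor([[1, 3]])          # (1, 2)
Mat2 = torch.tensor([[6, 4, 2],
                     [5, 3, 1]])       # (2, 3)

# Mat2 @ Mat1 would be (2, 3) x (1, 2): inner dims 3 and 1 differ -> error.
# Using (A B)^T = B^T A^T instead: (3, 2) x (2, 1) -> (3, 1)
Mat4 = torch.matmul(Mat2.T, Mat1.T)
print(Mat4.shape)   # torch.Size([3, 1])
print(torch.equal(Mat4, torch.matmul(Mat1, Mat2).T))  # True
```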
import torch
# first argument 1D and second argument 2D
mat1_1 = torch.tensor([3, 6, 2])
mat1_2 = torch.tensor([[1, 2, 3],
[4, 3, 8],
[1, 7, 2]])
out_1 = torch.matmul(mat1_1, mat1_2)
print("\n1D-2D matmul :\n", out_1)
# first argument 2D and second argument 1D
mat2_1 = torch.tensor([[1, 2, 3],
[4, 3, 8],
[1, 7, 2]])
mat2_2 = torch.tensor([3, 6, 2])
# compute the 2D-1D product and store it in out_2
out_2 = torch.matmul(mat2_1, mat2_2)
print("\n2D-1D matmul :\n", out_2)
Output:
1D-2D matmul :
tensor([29, 38, 61])
2D-1D matmul :
tensor([21, 46, 49])
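The first case (1D-2D matmul) can be understood as a $1\times3$ by $3\times3$ matrix multiplication: the 1D tensor is treated as a row vector, and the added dimension is removed from the result, so out_1 has shape [3].
The second case (2D-1D matmul) can be understood as a $3\times3$ by $3\times1$ matrix multiplication: the 1D tensor is treated as a column vector, and the result is again a 1D tensor of shape [3].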
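The row-vector / column-vector reading can be checked with torch.mm by adding the dimension explicitly with unsqueeze and removing it afterwards with squeeze. A sketch reusing the tensors from the example above (renamed v and M for brevity):

```python
import torch

v = torch.tensor([3, 6, 2])
M = torch.tensor([[1, 2, 3],
                  [4, 3, 8],
                  [1, 7, 2]])

# 1D x 2D: treat v as a (1, 3) row vector, multiply, then drop the added dim
out_1 = torch.matmul(v, M)
row_version = torch.mm(v.unsqueeze(0), M).squeeze(0)
print(torch.equal(out_1, row_version))   # True

# 2D x 1D: treat v as a (3, 1) column vector, multiply, then drop the added dim
out_2 = torch.matmul(M, v)
col_version = torch.mm(M, v.unsqueeze(1)).squeeze(1)
print(torch.equal(out_2, col_version))   # True

print(out_1.shape, out_2.shape)          # torch.Size([3]) torch.Size([3])
```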
import torch
# both arguments are 3D: a batch of two 3x3 matrices each
Mat1 = torch.tensor([[[1, 4,-9],
[2,-5,8],
[3,6,-7]],
[[2, 4, 6],
[1, 3, 5],
[7, 8, 9]]])
print("Mat1's shape: ",Mat1.shape)
Mat2 = torch.tensor([[[1, 2, 3],
[4, 3, 8],
[1, 7, 2]],
[[1, 7, 2],
[3, 2, 3],
[1, 1, 2]]])
print("Mat2's shape: ",Mat2.shape)
Out1 = torch.matmul(Mat1, Mat2)
print("\n3D-3D matmul :\n", Out1)
print("Out1's shape: ",Out1.shape)
Output:
Mat1's shape: torch.Size([2, 3, 3])
Mat2's shape: torch.Size([2, 3, 3])
3D-3D matmul :
tensor([[[ 8, -49, 17],
[-10, 45, -18],
[ 20, -25, 43]],
[[ 20, 28, 28],
[ 15, 18, 21],
[ 40, 74, 56]]])
Out1's shape: torch.Size([2, 3, 3])
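Note: both input tensors have shape [2, 3, 3], and the output tensor also has shape [2, 3, 3].
In effect, the two pairs of $3\times3$ matrices are multiplied slice by slice (Mat1[0] with Mat2[0], Mat1[1] with Mat2[1]), and the two products are stacked into an output of shape [2, 3, 3].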
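A sketch of the same computation spelled out: multiply the matching 3×3 slices with torch.mm, stack the results along dimension 0, and compare with torch.matmul. torch.bmm, PyTorch's batched matrix multiply for two 3D tensors with the same batch size, is shown as well:

```python
import torch

Mat1 = torch.tensor([[[1, 4, -9],
                      [2, -5, 8],
                      [3, 6, -7]],
                     [[2, 4, 6],
                      [1, 3, 5],
                      [7, 8, 9]]])     # (2, 3, 3)
Mat2 = torch.tensor([[[1, 2, 3],
                      [4, 3, 8],
                      [1, 7, 2]],
                     [[1, 7, 2],
                      [3, 2, 3],
                      [1, 1, 2]]])     # (2, 3, 3)

Out1 = torch.matmul(Mat1, Mat2)

# Multiply matching 3x3 slices, then stack the products along dim 0
per_batch = torch.stack([torch.mm(Mat1[i], Mat2[i])
                         for i in range(Mat1.shape[0])])
print(torch.equal(Out1, per_batch))              # True

# torch.bmm: batched matmul for two 3D tensors with equal batch size
print(torch.equal(Out1, torch.bmm(Mat1, Mat2)))  # True
```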