• Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果


    1 torch.mm() 函数

    全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维

    1.1 torch.mm() 函数定义及参数

    torch.bmm(input, mat2, , out=None) → Tensor
    input (Tensor) – – 第一个要相乘的矩阵
    ** mat2
    * (Tensor) – – 第二个要相乘的矩阵
    不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

    1.2 torch.bmm() 官方示例

    mat1 = torch.randn(2, 3)
    mat2 = torch.randn(3, 3)
    torch.mm(mat1, mat2)
    
    tensor([[ 0.4851,  0.5037, -0.3633],
            [-0.0760, -3.6705,  2.4784]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    2 torch.bmm() 函数

    全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

    2.1 torch.bmm() 函数定义及参数

    torch.bmm(input, mat2, , out=None) → Tensor
    input (Tensor) – – 第一批要相乘的矩阵
    ** mat2
    * (Tensor) – – 第二批要相乘的矩阵
    不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

    2.2 torch.bmm() 官方示例

    input = torch.randn(10, 3, 4)
    mat2 = torch.randn(10, 4, 5)
    res = torch.bmm(input, mat2)
    res.size()
    
    torch.Size([10, 3, 5])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    3 torch.matmul() 函数

    可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数

    3.1 torch.matmul() 函数定义及参数

    torch.matmul(input, mat2, , out=None) → Tensor
    input (Tensor) – – 第一个要相乘的张量
    ** mat2
    * (Tensor) – – 第二个要相乘的张量
    支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

    3.2 torch.matmul() 规则约定

    (1)若两个都是1D(向量)的,则返回两个向量的点积;

    (2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

    (3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

    (4)若input是2D,other是1D,则返回两者的点积结果;

    (5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

    • (a)若input是1D,other是大于2D的,则类似于规则(3);
    • (b)若other是1D,input是大于2D的,则类似于规则(4);
    • (c)若input和other都是3D的,则与torch.bmm()函数功能一样;
    • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

    matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

    3.3 torch.matmul() 官方示例

    # vector x vector
    tensor1 = torch.randn(3)
    tensor2 = torch.randn(3)
    torch.matmul(tensor1, tensor2).size()
    
    torch.Size([])
    # matrix x vector
    tensor1 = torch.randn(3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    
    torch.Size([3])
    # batched matrix x broadcasted vector
    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(4)
    torch.matmul(tensor1, tensor2).size()
    
    torch.Size([10, 3])
    # batched matrix x batched matrix
    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(10, 4, 5)
    torch.matmul(tensor1, tensor2).size()
    
    torch.Size([10, 3, 5])
    # batched matrix x broadcasted matrix
    tensor1 = torch.randn(10, 3, 4)
    tensor2 = torch.randn(4, 5)
    torch.matmul(tensor1, tensor2).size()
    
    torch.Size([10, 3, 5])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    3.4 高维数据实例解释

    直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算
    在这里插入图片描述

    代码如下:

    import torch
    import numpy as np
    
    np.random.seed(2022)
    a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
    a = torch.tensor(a)
    b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
    b = torch.tensor(b)
    c = torch.matmul(a, b)
    # or
    # c = a @ b
    print(a)
    print("=============================================")
    print(b)
    print("=============================================")
    print(c.size())
    print("=============================================")
    print(c)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    运行结果为:

    tensor([[[[1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 0, 0]],
    
             [[1, 1, 1, 1],
              [1, 1, 0, 0],
              [0, 1, 0, 1]]],
    
    
            [[[0, 0, 0, 1],
              [0, 0, 0, 1],
              [0, 1, 0, 0]],
    
             [[1, 1, 1, 1],
              [1, 1, 1, 1],
              [0, 0, 0, 0]]]], dtype=torch.int32)
    =============================================
    tensor([[[[0, 1, 0],
              [1, 1, 0],
              [0, 0, 0],
              [1, 1, 0]]],
    
    
            [[[0, 1, 0],
              [1, 1, 1],
              [1, 1, 1],
              [1, 0, 1]]]], dtype=torch.int32)
    =============================================
    torch.Size([2, 2, 3, 3])
    =============================================
    tensor([[[[0, 1, 0],
              [2, 3, 0],
              [0, 0, 0]],
    
             [[2, 3, 0],
              [1, 2, 0],
              [2, 2, 0]]],
    
    
            [[[1, 0, 1],
              [1, 0, 1],
              [1, 1, 1]],
    
             [[3, 3, 3],
              [3, 3, 3],
              [0, 0, 0]]]], dtype=torch.int32)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46

    参考博文及感谢

    部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
    参考博文1 官方文档查询地址
    https://pytorch.org/docs/stable/index.html
    参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
    https://blog.csdn.net/irober/article/details/113686080

  • 相关阅读:
    2023山东老博会·CSOLDE中国国际养老服务业展览会
    简单查找操作
    centos7安装adb工具(拒绝抄袭)
    动态RDLC报表(五)
    Spring @Transactional事务管理
    liunx的基础命令整理
    第一章初识Maven与Maven安装配置——尚硅谷
    无线充电器出口欧盟CE认证RED指令测试分析
    关于Vue3中对于响应式API和组合式API的理解
    自然语言处理(NLP)—— 神经网络自然语言处理(2)实际应用
  • 原文地址:https://blog.csdn.net/qq_39407949/article/details/132890694