• 矩阵相乘详解


    矩阵相乘详解

    已知三个矩阵 A , B , C A,B,C A,B,C

    在这里插入图片描述

    数学上的矩阵相乘 C = A × \times × B

    数学表示

    在这里插入图片描述

    程序表示

    多维矩阵:torch.matmul(A,B)

    if: A ∈ R n × m , B ∈ R m × n A\in R^{n\times m},B\in R^{m\times n} ARn×m,BRm×n

    then: torch.matmul(A, B) ∈ R n × n \in R^{n\times n} Rn×n

    二维矩阵相乘:torch.mm(A,B)

    # 矩阵相乘
    x = tensor([[1, 2, 3],
                [3, 3, 4],
                [3, 3, 3]])
    
    # torch.matmul表示矩阵的乘法
    torch.matmul(x,x)
    Out[1]: 
    tensor([[16, 17, 20],
            [24, 27, 33],
            [21, 24, 30]])
            
    # 两个维度对上就可以进行运算
    
    x = tensor([[1, 2, 3],
                [3, 3, 4],
                [3, 3, 3]])
                
    y = tensor([[1, 2],
                [3, 3],
                [4, 4]])
    torch.matmul(x, y)
    Out[2]: 
    tensor([[19, 20],
            [28, 31],
            [24, 27]])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    数学上的矩阵对位相乘

    数学表示

    在这里插入图片描述

    程序表示

    torch.mul(A,B)

    
    # 表示矩阵对位相乘
    x = tensor([[1, 2, 3],
                [3, 3, 4],
                [3, 3, 3]])
    # 方法1
    x * x
    Out[3]: 
    tensor([[ 1,  4,  9],
            [ 9,  9, 16],
            [ 9,  9,  9]])
    
    # 方法2        
    torch.mul(x,x)
    Out[4]: 
    tensor([[ 1,  4,  9],
            [ 9,  9, 16],
            [ 9,  9,  9]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18

    带有batch的三维就一阵相乘

    torch.bmm(A, B)

    A ∈ R B × n × m A\in R^{B\times n\times m} ARB×n×m, B ∈ R B × m × d B\in R^{B\times m\times d} BRB×m×d

    torch.bmm(A, B) ∈ R B × n × d \in R^{B\times n\times d} RB×n×d

    t = tensor([[[1, 2, 3],
                 [3, 3, 4],
                 [3, 3, 3]],
                 
                [[1, 2, 3],
                 [3, 3, 4],
                 [3, 3, 3]]])
    
    T = torch.bmm(t, t)
    T.shape
    
    Out[5]: torch.Size([2, 3, 3])
    
    T
    Out[6]: 
    tensor([[[16, 17, 20],
             [24, 27, 33],
             [21, 24, 30]],
            [[16, 17, 20],
             [24, 27, 33],
             [21, 24, 30]]])
             
             
    # 两个维度不同
    u = tensor([[[1, 2],
                 [3, 3],
                 [4, 4]],
                [[1, 2],
                 [3, 3],
                 [4, 4]]])
    t = tensor([[[1, 2, 3],
                 [3, 3, 4],
                 [3, 3, 3]],
                [[1, 2, 3],
                 [3, 3, 4],
                 [3, 3, 3]]])
                 
    u.shape
    Out[7]: torch.Size([2, 3, 2])
    t.shape
    Out[8]: torch.Size([2, 3, 3])
    
    torch.bmm(t, u)
    Out[9]: 
    tensor([[[19, 20],
             [28, 31],
             [24, 27]],
            [[19, 20],
             [28, 31],
             [24, 27]]])
    
    torch.bmm(t, u).shape
    Out[10]: torch.Size([2, 3, 2])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
  • 相关阅读:
    【填坑】乐鑫ESP32C3 Bootloader开发(上)
    流的基本概念以及常见应用
    C++——C++入门(二)
    我们为什么需要调用InitCommonControls?
    服务器数据恢复-阵列崩溃导致LVM结构破坏的数据恢复案例
    Linux进程管理2
    【SpringMVC】SpringMVC接受请求参数和数据回显
    我偷偷学了这5个命令,打印Linux环境变量那叫一个“丝滑”!
    Python学习笔记--对象的描述器
    Java自动化框架:jenkins执行git命令
  • 原文地址:https://blog.csdn.net/Jeaksun/article/details/126233437