PyTorch Study Notes (9): Making Sense of Multiplication in torch


    There are plenty of articles online about multiplication in torch, but they are scattered, so I put together my own summary.
    The point of this post is not to explain how torch implements these operations, nor to walk through the source code or documentation; it only covers which method to call in which situation. I only describe the methods I have actually used, and I will add new ones as I encounter them.

    All calculations in this post use the following two matrices as examples:

    $a = \begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix}, \quad b = \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix}$

    Now let us create these two matrices in torch:

    import torch

    # tensor([[1, 1],
    #         [2, 2]])
    a = torch.tensor([[1, 1], [2, 2]])
    # tensor([[1, 2],
    #         [1, 2]])
    b = torch.tensor([[1, 2], [1, 2]])
    

    1 Matrix Multiplication

    In terms of dimensions, matrix multiplication is $[x \times n] \cdot [n \times y] = [x \times y]$. For the details of how it is computed, consult any linear algebra book or course; I will not repeat them here. Computing $a \times b$ by hand gives:

    $\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \cdot \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 2 & 4 \\ 4 & 8 \end{bmatrix}$
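
    As a worked step, each entry of the product is the dot product of a row of $a$ with a column of $b$; for example, the top-left entry comes from the first row of $a$ and the first column of $b$:

    $1 \cdot 1 + 1 \cdot 1 = 2, \qquad \text{top-right:} \quad 1 \cdot 2 + 1 \cdot 2 = 4.$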

    Notes:

    1. Matrix multiplication can be viewed as a collection of vector dot products.
    2. The dot product can be written in terms of the vectors' norms and the angle between them, $a \cdot b = |a||b|\cos(\theta)$, which gives the cosine-similarity formula $\cos(\theta) = (a \cdot b) / (|a||b|)$. The dot product therefore reflects, to some degree, how similar two vectors are, which is why it appears so often in attention mechanisms, e.g. the self-attention score $\alpha = \boldsymbol{q} \cdot \boldsymbol{k}^{\rm T}$; see the sketch below.
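
    To make the connection concrete, here is a minimal sketch (the vectors are my own example) computing a dot-product attention-style score and the matching cosine similarity:

    import torch
    import torch.nn.functional as F

    q = torch.tensor([1.0, 1.0])
    k = torch.tensor([2.0, 2.0])

    # attention-style score: the raw dot product q . k
    score = torch.dot(q, k)                # tensor(4.)

    # cosine similarity: (q . k) / (|q||k|); parallel vectors give 1
    cos = score / (q.norm() * k.norm())    # tensor(1.)
    print(torch.allclose(cos, F.cosine_similarity(q, k, dim=0)))  # True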

    PyTorch provides the following ways to perform matrix multiplication:

    1.1 Vector multiplication

    Vector multiplication is the dot product, implemented by dot in torch. Take the first row of a, $[1, 1]$, and the second column of b, $[2, 2]$: the hand-computed result is 4. In torch:

    c = torch.tensor([1, 1])
    d = torch.tensor([2, 2])
    torch.dot(c, d)
    # tensor(4)
    

    Note: dot only works on vectors; inputs with more than one dimension raise an error. Calling torch.dot(a, b), for example, fails with:

    RuntimeError: 1D tensors expected, but got 2D and 2D tensors
    
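    If what you actually want is the sum of the element-wise products of two matrices (the Frobenius inner product), one workaround sketch is to flatten both inputs first:

    # flatten both matrices to 1-D, then take the dot product
    torch.dot(a.flatten(), b.flatten())
    # tensor(9)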

    1.2 Matrix multiplication

    Matrix multiplication is implemented in torch with mm:

    torch.mm(a, b)
    # tensor([[2, 4],
    #         [4, 8]])
    

    This matches the result we computed by hand.
    Note: mm only works on matrices; if an input is not 2-dimensional, it raises an error:

    RuntimeError: self must be a matrix
    
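    One convenient fact worth noting: for 2-D tensors, torch.mm, torch.matmul, and Python's @ operator all compute the same product (@ on tensors dispatches to matmul), though mm does not broadcast:

    # all equivalent for 2-D inputs
    print(torch.mm(a, b).equal(a @ b))               # True
    print(torch.mm(a, b).equal(torch.matmul(a, b)))  # True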

    1.3 Tensor multiplication

    torch has two kinds of tensor multiplication, bmm and matmul, which differ as follows:

    1.3.1 Batched matrix multiplication

    The b in bmm stands for batch, i.e. batched matrix multiplication. The inputs must be 3-dimensional, with the first dimension being the batch dimension; in short, each element of the batch goes through one matrix multiplication. In rough pseudocode:

    # pseudocode: one matrix product per batch element
    for i in range(batch_size):
        out[i] = a[i] @ b[i]
    

    Let the batch size be 1: we add a batch dimension to a and b with unsqueeze, then call bmm:

    # shape: (1, 2, 2)
    a = a.unsqueeze(0)
    b = b.unsqueeze(0)
    
    # shape: (1, 2, 2)
    torch.bmm(a, b)
    # tensor([[[2, 4],
    #          [4, 8]]])
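
    To confirm that bmm really is a per-element loop of matrix products, a quick check (looping with mm and stacking, as in the pseudocode above):

    looped = torch.stack([torch.mm(a[i], b[i]) for i in range(a.size(0))])
    print(torch.bmm(a, b).equal(looped))  # True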
    

    Note: bmm only works on 3-dimensional tensors; any other dimensionality raises an error:

    # with 2-dimensional inputs
    RuntimeError: Expected 3-dimensional tensor, but got 2-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
    # with 4-dimensional inputs
    RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
    

    1.3.2 All-purpose multiplication

    matmul is arguably the most versatile multiplication in torch, and is best explained alongside the torch documentation:

    • If both tensors are 1-dimensional, the dot product (scalar) is returned.

    • If both arguments are 2-dimensional, the matrix-matrix product is returned.

    • If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.

    • If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.

    • If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if input is a $(j \times 1 \times n \times n)$ tensor and other is a $(k \times n \times n)$ tensor, out will be a $(j \times k \times n \times n)$ tensor.

    • Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input is a $(j \times 1 \times n \times m)$ tensor and other is a $(k \times m \times p)$ tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a $(j \times k \times n \times p)$ tensor.

    In short, matmul still performs matrix multiplication, but it fills in missing dimension information automatically, and it always multiplies over the last two dimensions. Running it at each dimensionality:

    # vectors
    a = torch.tensor([1, 1])
    b = torch.tensor([2, 2])
    # result shape: () -- yes, a scalar, so no dimensions at all
    torch.matmul(a, b)
    # tensor(4)

    # matrices
    a = torch.tensor([[1, 1], [2, 2]])
    b = torch.tensor([[1, 2], [1, 2]])
    # result shape: (2, 2)
    torch.matmul(a, b)
    # tensor([[2, 4],
    #         [4, 8]])

    # 3-D tensors
    a, b = a.unsqueeze(0), b.unsqueeze(0)
    # result shape: (1, 2, 2)
    torch.matmul(a, b)
    # tensor([[[2, 4],
    #          [4, 8]]])

    # 4-D tensors
    a, b = a.unsqueeze(0), b.unsqueeze(0)
    # result shape: (1, 1, 2, 2)
    torch.matmul(a, b)
    # tensor([[[[2, 4],
    #           [4, 8]]]])

    # 5-D tensors
    a, b = a.unsqueeze(0), b.unsqueeze(0)
    # result shape: (1, 1, 1, 2, 2)
    torch.matmul(a, b)
    # tensor([[[[[2, 4],
    #            [4, 8]]]]])
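
    The batch-dimension broadcasting described in the last two documentation bullets can be checked directly. A minimal sketch with random tensors (the shapes mirror the docs' $(j \times 1 \times n \times n)$ example; the concrete sizes are my own choice):

    x = torch.randn(4, 1, 3, 3)  # (j, 1, n, n) with j = 4
    y = torch.randn(5, 3, 3)     # (k, n, n) with k = 5
    print(torch.matmul(x, y).shape)  # torch.Size([4, 5, 3, 3])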
    

    2 Element-wise Multiplication

    Element-wise multiplication (the element-wise, or Hadamard, product) multiplies the entries in row $i$, column $j$ of the two matrices directly. For a and b, computing by hand (here $\otimes$ denotes the element-wise product):

    $\begin{bmatrix} 1 & 1 \\ 2 & 2 \end{bmatrix} \otimes \begin{bmatrix} 1 & 2 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix}$

    2.1 Direct multiplication

    In torch, element-wise multiplication can be done directly with *:

    # 1-D
    a = torch.tensor([1, 1])
    b = torch.tensor([2, 2])
    # result shape: (2,)
    a * b
    # tensor([2, 2])

    # 2-D
    a = torch.tensor([[1, 1], [2, 2]])
    b = torch.tensor([[1, 2], [1, 2]])
    # result shape: (2, 2)
    a * b
    # tensor([[1, 2],
    #         [2, 4]])

    # 3-D
    a, b = a.unsqueeze(0), b.unsqueeze(0)
    # result shape: (1, 2, 2)
    a * b
    # tensor([[[1, 2],
    #          [2, 4]]])
    

    2.2 Using the library function

    Of course, torch also offers a library function for element-wise multiplication, mul:

    # 1-D
    a = torch.tensor([1, 1])
    b = torch.tensor([2, 2])
    # result shape: (2,)
    torch.mul(a, b)
    # tensor([2, 2])

    # 2-D
    a = torch.tensor([[1, 1], [2, 2]])
    b = torch.tensor([[1, 2], [1, 2]])
    # result shape: (2, 2)
    torch.mul(a, b)
    # tensor([[1, 2],
    #         [2, 4]])

    # 3-D
    a, b = a.unsqueeze(0), b.unsqueeze(0)
    # result shape: (1, 2, 2)
    torch.mul(a, b)
    # tensor([[[1, 2],
    #          [2, 4]]])
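
    Both * and torch.mul broadcast like other element-wise operations, so the two shapes only need to be broadcastable rather than identical. A small sketch (the values are my own illustration):

    m = torch.tensor([[1, 1], [2, 2]])
    row = torch.tensor([10, 20])  # shape (2,) broadcasts across the rows of m
    print(m * row)
    # tensor([[10, 20],
    #         [20, 40]])
    print(torch.mul(m, row).equal(m * row))  # True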
    
    Original post: https://blog.csdn.net/qq_35357274/article/details/126585802