• PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()


    torch.mul()

    函数功能:逐个对 inputother 中对应的元素相乘。

    本操作支持广播,因此 inputother 均可以是张量或者数字。

    举例如下:

    >>> import torch
    >>> a = torch.randn(3)
    >>> a
    tensor([-1.7095,  1.7837,  1.1865])
    >>> b = 2
    >>> torch.mul(a, b)
    tensor([-3.4190,  3.5675,  2.3730])     # 这里将 other 扩展成了 input 的形状
    
    >>> a = 3
    >>> b = torch.randn(3, 1)
    >>> b
    tensor([[-0.7705],
            [ 1.1177],
            [ 1.2447]])
    >>> torch.mul(a, b)
    tensor([[-2.3116],
            [ 3.3530],
            [ 3.7341]])                     # 这里将 input 扩展成了 other 的形状
    
    >>> a = torch.tensor([[2], [3]])
    >>> a
    tensor([[2],
            [3]])                           # a 是 2×1 的张量
    >>> b = torch.tensor([-1, 2, 1])
    >>> b
    tensor([-1,  2,  1])                    # b 是 1×3 的张量
    >>> torch.mul(a, b)
    tensor([[-2,  4,  2],
            [-3,  6,  3]])
    

    这个例子中,inputoutput 的形状都不是公共形状,因此两个都需要广播,都变成 2×3 的形状,然后再逐个元素相乘。

    由上述例子可以看出,这种乘法是逐个对应元素相乘,因此 inputoutput 的前后顺序并不影响结果,即 torch.mul(a, b) =torch.mul(b, a)

    官方文档

    torch.multiply()

    torch.mul() 的别称。

    torch.dot()

    函数功能:计算 inputoutput 的点乘,此函数要求 inputoutput必须是一维的张量(其 shape 属性中只有一个值)!并且要求两者元素个数相同

    举例如下:

    >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
    tensor(7)
    
    >>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1, 1]))		# 要求两者元素个数相同
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: inconsistent tensor size, expected tensor [2] and src [3] to have the same number of elements, but got 2 and 3 elements respectively
    

    官方文档

    torch.mm()

    函数功能:实现线性代数中的矩阵乘法(matrix multiplication):(n×m) × (m×p) = (n×p)

    本函数不允许广播!

    举例如下:

    >>> mat1 = torch.randn(2, 3)
    >>> mat2 = torch.randn(3, 2)
    >>> torch.mm(mat1, mat2)
    tensor([[-1.1846, -1.8327],
            [ 0.8820,  0.0312]])
    

    官方文档

    torch.mv()

    函数功能:实现矩阵和向量(matrix × vector)的乘法,要求 input 的形状为 n×moutputtorch.Size([m])的一维 tensor。

    举例如下:

    >>> mat = torch.tensor([[1, 2, 3], [4, 5, 6]])
    >>> mat
    tensor([[1, 2, 3],
            [4, 5, 6]])
    >>> vec = torch.tensor([-1, 1, 2])
    >>> vec
    tensor([-1,  1,  2])
    >>> mat.shape
    torch.Size([2, 3])
    >>> vec.shape
    torch.Size([3])
    >>> torch.mv(mat, vec)
    tensor([ 7, 13])
    

    注意,此函数要求第二个参数是一维 tensor,也即其 ndim 属性值为 1。这里我们要区分清楚张量的 shape 属性和 ndim 属性,前者表示张量的形状,后者表示张量的维度。(线性代数中二维矩阵的维度 m×n 通常理解为这里的形状)

    对于 shape 值为 torch.Size([n])torch.Size(1, n) 的张量,前者的 ndim=1 ,后者的 ndim=2 ,因此前者是可视为线代中的向量,后者可视为线代中的矩阵。

    对于 shape 值为 torch.Size([1, n])torch.Size([n, 1]) 的张量,它们同样在 Pytorch 中被视为矩阵。例如:

    >>> column = torch.tensor([[1], [2]])
    >>> row = torch.tensor([3, 4])
    >>> column.shape
    torch.Size([2, 1])				# 矩阵
    >>> row.shape
    torch.Size([2])					# 一维张量
    >>> matrix = torch.randn(1, 3)
    >>> matrix.shape
    torch.Size([1, 3])				# 矩阵
    

    对于张量(以及线代中的向量和矩阵)的理解可看这篇博文

    官方文档

    torch.bmm()

    函数功能:实现批量的矩阵乘法。

    本函数要求 inputoutputndim 均为 3,且前者形状为 b×n×m,后者形状为 b×m×p 。可以理解为 input 中包含 b 个形状为 n×m 的矩阵, output 中包含 b 个形状为 m×p 的矩阵,然后第一个 n×m 的矩阵 × 第一个 m×p 的矩阵得到第一个 n×p 的矩阵,第二个……,第 b 个……因此最终得到 b 个形状为 n×p 的矩阵,即最终结果是一个三维张量,形状为 b×n×p

    举例如下:

    >>> batch_matrix_1 = torch.tensor([ [[1, 2], [3, 4], [5, 6]] , [[-1, -2], [-3, -4], [-5, -6]] ])
    >>> batch_matrix_1
    tensor([[[ 1,  2],
             [ 3,  4],
             [ 5,  6]],
    
            [[-1, -2],
             [-3, -4],
             [-5, -6]]])
    >>> batch_matrix_1.shape
    torch.Size([2, 3, 2])
    
    >>> batch_matrix_2 = torch.tensor([ [[1, 2], [3, 4]], [[1, 2], [3, 4]] ])
    >>> bat
    batch_matrix_1 batch_matrix_2
    >>> batch_matrix_2
    tensor([[[1, 2],
             [3, 4]],
    
            [[1, 2],
             [3, 4]]])
    >>> batch_matrix_2.shape
    torch.Size([2, 2, 2])
    
    >>> torch.bmm(batch_matrix_1, batch_matrix_2)
    tensor([[[  7,  10],
             [ 15,  22],
             [ 23,  34]],
    
            [[ -7, -10],
             [-15, -22],
             [-23, -34]]])
    

    官方文档

    torch.matmul()

    torch.matmul() 可以用于 PyTorch 中绝大多数的乘法,在不同的情形下,它与上述各个乘法函数起着相同的作用,具体请看这篇博文

  • 相关阅读:
    运维-- 统一网关非常必要
    Android 开发者的跨平台 - Flutter or Compose ?
    TreeUtils工具类一行代码实现列表转树【第三版优化】 三级菜单 三级分类 附视频
    一个四位数,恰好等于去掉它的首位数字之后所剩的三位数的3倍,这个四位数是多少?
    1.2 w字+!Java IO 基础知识系统总结 | JavaGuide
    如何在 SOLIDWORKS中创建零件模板 硕迪科技
    JMeter界面字体大小设置方法
    HTML+CSS网页设计期末课程大作业 【茶叶文化网站设计题材】web前端开发技术 web课程设计 网页规划与设计
    Restriction (mathematics)
    什么是语句?什么是表达式?
  • 原文地址:https://www.cnblogs.com/HOMEofLowell/p/15962140.html