• 【python深度学习】——torch.einsum|torch.bmm


    【python深度学习】——torch.einsum|torch.bmm

    1. 基本用法与示例

    基本用法:

    torch.einsum(equation, *operands)
    
    • equation: 一个字符串,定义了张量操作的模式。
      使用逗号来分隔输入张量的索引,然后是一个箭头(->),接着是输出张量的索引
    • operands: 要操作的张量。
      示例代码:
    import torch
    A = torch.randn(2, 3)
    
    B = torch.einsum('ij->ji', A)
    # 等价于 B = A.transpose(0, 1)
    
    C = torch.einsum('ik,kj->ij', A, B)
    # 等价于 C = torch.matmul(A, B)
    
    a = torch.randn(3)
    b = torch.randn(3)
    c = torch.einsum('i,i->', a, b)
    # 等价于 c = torch.dot(a, b)
    
    
    A = torch.randn(5, 2, 3)
    B = torch.randn(5, 3, 4)
    C = torch.einsum('bij,bjk->bik', A, B)
    # 等价于 C = torch.bmm(A, B)
    
    
    a = torch.randn(3)
    b = torch.randn(4)
    c = torch.einsum('i,j->ij', a, b)
    # 结果是一个3x4的矩阵,等价于 c = a.unsqueeze(1) * b.unsqueeze(0)
    
    
    A = torch.randn(3, 3)
    trace = torch.einsum('ii->', A)
    # 等价于 trace = torch.trace(A)
    
    
    

    2. torch.bmm

    全称为: batch matrix-matrix product, 批量矩阵乘法, 适用于三维张量,其中第一维表示批量大小,第二维和第三维表示矩阵的行和列

    torch.bmm(input, mat2, *, out=None) -> Tensor
    
    • input: 一个形状为 (b, n, m) 的三维张量,表示一批矩阵。
    • mat2: 一个形状为 (b, m, p) 的三维张量,表示另一批矩阵。
    • out (可选): 存储输出结果的张量。
      输出是一个形状为 (b, n, p) 的张量,其中每个矩阵是对应批次的矩阵乘法结果。

    例如:

    import torch
    
    # 定义两个形状为 (b, n, m) 和 (b, m, p) 的三维张量
    batch_size = 10
    n, m, p = 3, 4, 5
    
    A = torch.randn(batch_size, n, m)
    B = torch.randn(batch_size, m, p)
    
    # 进行批量矩阵乘法
    C = torch.bmm(A, B)
    
    print(C.shape)  # 输出: torch.Size([10, 3, 5])
    
    

    再具体的:

    A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    B = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
    
    # A.shape = (2, 2, 2)
    # B.shape = (2, 2, 2)
    C = torch.bmm(A, B)
    
    print(C)
    # 输出:
    # tensor([[[ 31,  34],
    #          [ 73,  80]],
    #
    #         [[155, 166],
    #          [211, 226]]])
    
    

    数学计算为:
    请添加图片描述

  • 相关阅读:
    虾皮插件能做数据分析的-知虾数据分析插件Shopee大数据分析平台
    【第2章 Node.js基础】2.4 Node.js 全局对象(二) process 对象
    基于SpringBoot的二手商品交易平台
    Linux安装jdk的详细步骤
    c++11 智能指针 (std::shared_ptr)(五)
    python性能分析
    国学---佛系算吉凶~
    webpack 面试题
    题解——二维费用背包问题(宠物小精灵之收服、潜水员)
    字符串算法
  • 原文地址:https://blog.csdn.net/steptoward/article/details/139471205