• 使用torch普通算子组合替代torch.einsum爱因斯坦求和


    1. torch.einsum(‘bnd, bmd->bnm’, x, y)

    torch.einsum(‘bnd, bmd->bnm’, x, y) 表示的是对张量 x 和 y 进行特定的求和和维度变换。

    具体来说,这个操作的输入是两个形状为 [b, n, d] 和 [b, m, d] 的张量 x 和 y,输出是一个形状为 [b, n, m] 的张量 z。其计算过程可以理解为:对于每个 b,z[b, n, m] 等于 x[b, n, :] 和 y[b, m, :] 之间的点积。

    为了用普通的 torch 操作符来替代 einsum,我们可以通过 torch.matmul 函数实现。这个函数可以用来执行批量矩阵乘法,并且能够很好地替代这个 einsum 操作。

    具体实现如下:

    import torch
    
    # 假设 x 和 y 的形状分别为 (b, n, d) 和 (b, m, d)
    x = torch.randn(10, 20, 30)  # 举例
    y = torch.randn(10, 15, 30)  # 举例
    
    # einsum: z = torch.einsum('bnd, bmd->bnm', x, y)
    # 可以转换为以下操作:
    z = torch.matmul(x, y.transpose(-1, -2))  # z 的形状为 (b, n, m)
    
    # 检查 z 的形状是否正确
    print(z.shape)
    

    2. torch.einsum(‘ij,jk->ik’, A, B)

    可以用普通的矩阵乘法 torch.matmul 替代

    具体实现如下:

    import torch
    
    A = torch.rand(3, 4)
    B = torch.rand(4, 5)
    
    # 使用 einsum
    result_einsum = torch.einsum('ij,jk->ik', A, B)
    
    # 使用 matmul
    result_matmul = torch.matmul(A, B)
    
    # 验证结果相同
    print(torch.allclose(result_einsum, result_matmul))
    

    3. torch.einsum(‘bij,bjk->bik’, A, B)

    可以用 torch.bmm 来替代

    具体实现如下:

    import torch
    
    A = torch.rand(10, 3, 4)
    B = torch.rand(10, 4, 5)
    
    # 使用 einsum
    result_einsum = torch.einsum('bij,bjk->bik', A, B)
    
    # 使用 bmm
    result_bmm = torch.bmm(A, B)
    
    # 验证结果相同
    print(torch.allclose(result_einsum, result_bmm))
    

    4. torch.einsum(‘i,i->’, A, B)

    向量内积,可以用 torch.dot 来替代

    具体实现如下:

    import torch
    
    A = torch.rand(4)
    B = torch.rand(4)
    
    # 使用 einsum
    result_einsum = torch.einsum('i,i->', A, B)
    
    # 使用 dot
    result_dot = torch.dot(A, B)
    
    # 验证结果相同
    print(torch.allclose(result_einsum, result_dot))
    

    5. torch.einsum(‘i,j->ij’, A, B)

    向量外积,可以用 torch.outer 来替代

    具体实现如下:

    import torch
    
    A = torch.rand(4)
    B = torch.rand(5)
    
    # 使用 einsum
    result_einsum = torch.einsum('i,j->ij', A, B)
    
    # 使用 outer
    result_outer = torch.outer(A, B)
    
    # 验证结果相同
    print(torch.allclose(result_einsum, result_outer))
    

    不同的 einsum 表达式会对应不同的替代操作,有时可能需要组合多个普通操作来达到相同的效果。如果某些 einsum 表达式太复杂,使用普通算子替代时会比较繁琐,此时建议继续使用 einsum,因为它不仅更简洁,而且通常性能优化得很好。

    后续遇到其余需替换的 op 再进行更新

  • 相关阅读:
    Grandle安装配置使用
    他山之石,可以攻玉, 改造fasthttp实现高性能网络通信
    分布式消息通信之Kafka的实现原理
    仪酷LabVIEW OD实战(3)——Object Detection+onnx工具包快速实现yolo目标检测
    KMP算法
    前端如何优化工程
    [面试直通版]操作系统之锁、同步与通信(上)
    数据之道读书笔记-10未来已来:数据成为企业核心竞争力
    计算机毕设 LSTM的预测算法 - 股票预测 天气预测 房价预测
    微服务集成redis并通过redis实现排行榜的功能
  • 原文地址:https://blog.csdn.net/libo1004/article/details/140959776