• 【深度学习模型移植】用torch普通算子组合替代torch.einsum方法


         首先不得不佩服大模型的强大之处,在算法移植过程中遇到einsum算子在ONNX中不支持,因此需要使用普通算子替代。参考TensorRT - 使用torch普通算子组合替代torch.einsum爱因斯坦求和约定算子的一般性方法可以写出简单的替换方法,但是该方法会导致训练时还是推理都很慢,并且会消耗大量显存,造成显存溢出的问题。。因此采用提问文心一言,没想到居然真的回答正确了。当然替换需要验证,不是全对的。
    1.einsum(delta, A, ‘b l d_in, d_in n -> b l d_in n’) 的替换,以下两个方法均可以

    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaA = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
    deltaA = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
    
    • 1
    • 2
    • 3

    2.einsum(x, C[:, i, :], ‘b d_in n, b n -> b d_in’),以下两个方法均可以

        
        y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
        y = (x*C[:, i, :].unsqueeze(dim=1)).sum(dim=2)
        y = torch.matmul(C[:, i, :], x.transpose(-1, -2)).squeeze(1)
    
    • 1
    • 2
    • 3
    • 4

    3.einsum(delta, B, u, ‘b l d_in, b l n, b l d_in -> b l d_in n’),以下两个方法均可以

    deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
    deltaB_u1 = delta.unsqueeze(dim=3)*B.unsqueeze(dim=2)*u.unsqueeze(dim=3)
    
    • 1
    • 2

    下述方法是提问文心一言的办法,注意需要将答案的结果和einsum的结果进行对比,采用np.testing.assert_allclose(deltaB_u.numpy(),deltaB_u1.numpy(),rtol=1e-05,atol=1e-05)和print(deltaA.equal(deltaA_manual))均可以。

    import torch
    import numpy as np
    from einops import rearrange, repeat, einsum
    # 给定的张量
    delta = torch.ones([1, 3, 2])
    A = torch.ones([2, 4])
    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaA1 = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
    deltaA_manual = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
    np.testing.assert_allclose(deltaA.numpy(),deltaA1.numpy(),rtol=1e-05,atol=1e-05)
    
    # 扩展 delta 的维度,以便它可以与 A 进行广播(broadcast)
    # 这里我们使用 unsqueeze 和 repeat_interleave 来扩展维度
    delta_expanded = delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1)
    # 执行逐元素的乘法,然后取指数
    deltaA_manual = torch.exp(delta_expanded * A)
    
    # 注意:deltaA_manual 的形状是 [1, 3, 2, 4],这与 einsum 的输出形状一致
    print(deltaA.equal(deltaA_manual))
    print(deltaA1.equal(deltaA_manual))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    请添加图片描述
    请添加图片描述
    请添加图片描述

  • 相关阅读:
    什么是原生IP?原生IP与住宅IP有何区别?
    flink on k8s
    设计模式—创建型模式之单例模式
    vim 窗口管理
    JavaScript进阶之路(一)初学者的开始
    【报错记录】spring boot 版本升级2.6.8 到之后的swagger3报错
    Python基本数据类型介绍
    Mysql中EXPLAIN解读
    【爬虫】Python使用动态IP,多线程,爬取uncomtrade的数据
    C:sprintf和snprintf的陷阱
  • 原文地址:https://blog.csdn.net/weixin_43509698/article/details/136753505