码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • 【深度学习模型移植】用torch普通算子组合替代torch.einsum方法


         首先不得不佩服大模型的强大之处,在算法移植过程中遇到einsum算子在ONNX中不支持,因此需要使用普通算子替代。参考TensorRT - 使用torch普通算子组合替代torch.einsum爱因斯坦求和约定算子的一般性方法。可以写出简单的替换方法,但是该方法会导致训练时还是推理都很慢,并且会消耗大量显存,造成显存溢出的问题。。因此采用提问文心一言,没想到居然真的回答正确了。当然替换需要验证,不是全对的。
    1.einsum(delta, A, ‘b l d_in, d_in n -> b l d_in n’) 的替换,以下两个方法均可以

    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaA = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
    deltaA = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
    
    • 1
    • 2
    • 3

    2.einsum(x, C[:, i, :], ‘b d_in n, b n -> b d_in’),以下两个方法均可以

        
        y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
        y = (x*C[:, i, :].unsqueeze(dim=1)).sum(dim=2)
        y = torch.matmul(C[:, i, :], x.transpose(-1, -2)).squeeze(1)
    
    • 1
    • 2
    • 3
    • 4

    3.einsum(delta, B, u, ‘b l d_in, b l n, b l d_in -> b l d_in n’),以下两个方法均可以

    deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
    deltaB_u1 = delta.unsqueeze(dim=3)*B.unsqueeze(dim=2)*u.unsqueeze(dim=3)
    
    • 1
    • 2

    下述方法是提问文心一言的办法,注意需要将答案的结果和einsum的结果进行对比,采用np.testing.assert_allclose(deltaB_u.numpy(),deltaB_u1.numpy(),rtol=1e-05,atol=1e-05)和print(deltaA.equal(deltaA_manual))均可以。

    import torch
    import numpy as np
    from einops import rearrange, repeat, einsum
    # 给定的张量
    delta = torch.ones([1, 3, 2])
    A = torch.ones([2, 4])
    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaA1 = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
    deltaA_manual = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
    np.testing.assert_allclose(deltaA.numpy(),deltaA1.numpy(),rtol=1e-05,atol=1e-05)
    
    # 扩展 delta 的维度,以便它可以与 A 进行广播(broadcast)
    # 这里我们使用 unsqueeze 和 repeat_interleave 来扩展维度
    delta_expanded = delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1)
    # 执行逐元素的乘法,然后取指数
    deltaA_manual = torch.exp(delta_expanded * A)
    
    # 注意:deltaA_manual 的形状是 [1, 3, 2, 4],这与 einsum 的输出形状一致
    print(deltaA.equal(deltaA_manual))
    print(deltaA1.equal(deltaA_manual))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    请添加图片描述
    请添加图片描述
    请添加图片描述

  • 相关阅读:
    python 基于django医院预约挂号管理系统
    局部线性嵌入LLE算法--学习笔记
    PacBio三代宏基因组测序大幅提升海洋水体宏基因组研究效果
    webpack:关于处理html文件的插件html-webpack-plugin、add-asset-html-webpack-plugin
    MATLAB使用OMP实现图像的压缩感知实例
    vscode的git 工具使用
    【Android】画面卡顿优化列表流畅度三之RecyclerView刷新机制notifyItemRangeInserted
    多种DNS的详细搭建方案和实现步骤,自建DNS防止DNS污染、DNS劫持,包括基于bind的named、dnsmasq、c-l-a-s-hdns等。
    java毕业生设计药品管理系统演示录像2021计算机源码+系统+mysql+调试部署+lw
    Python中import出现路径错误总结
  • 原文地址:https://blog.csdn.net/weixin_43509698/article/details/136753505
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号