• MindSpore尝鲜之Vmap功能


    技术背景

    Vmap是一种在python里面经常提到的向量化运算的功能,比如之前大家常用的就是numba和jax中的向量化运算的接口。虽然numpy中也使用到了向量化的运算,比如计算两个numpy数组的加和,就是一种向量化的运算。但是在numpy中模块封装的较好,定制化程度低,但是使用便捷,只需要调用最上层的接口即可。现在最新版本的mindspore也已经推出了vmap的功能,像mindspore、numba还有jax,与numpy的最大区别就是,需要在使用过程中对需要向量化运算的函数额外嵌套一层vmap的函数,这样就可以实现只对需要向量化运算的模块进行扩展。用一个公式来理解向量化运算的话就是:

    a1+b1=c1a2+b2=c2...an+bn=cna+b=c
    a1+b1=c1a2+b2=c2...an+bn=cna⃗ +b⃗ =c⃗ 

    安装最新版MindSpore

    关于jax中的vmap使用案例,可以参考前面介绍的LINCS约束算法实现SETTLE约束算法批量化实现这两篇文章,都有使用到jax的vmap功能,这里我们着重介绍的是MindSpore中最新实现的vmap功能。首先我们需要安装mindspore最新的Nightly版本,其对应的是MindSpore的Gitee仓库中的master分支,具体安装指令可以参考其官方链接

    因为我们本地已经安装过Mindspore的旧版本,因此还需要在安装指令之后加上--upgrade操作,否则会导致系统误以为本地已经安装成功,不会执行安装的操作:

    $ python3 -m pip install mindspore-cuda11-dev -i https://pypi.tuna.tsinghua.edu.cn/simple --upgrade
    

    Vmap功能测试

    这里我们先来看一个比较简单的示例:

    In [1]: from mindspore import Tensor
    
    In [2]: from mindspore.ops.functional import vmap
    
    In [3]: y = lambda a,b: a+b
    
    In [4]: A = Tensor([1,2,3])
    
    In [5]: B = Tensor([3,4,5])
    
    In [6]: vmap_y = vmap(y,in_axes=(0,0))
    
    In [7]: y(A[0],B[0]) # 元素加和
    Out[7]: Tensor(shape=[], dtype=Int64, value= 4)
    
    In [8]: vmap_y(A,B) # 矢量加和
    Out[8]: Tensor(shape=[3], dtype=Int64, value= [4, 6, 8])
    

    在上面的这个示例中,我们定义了一个加法函数y,作用就是把输入的两个对象相加。这里需要注意的是,如果输入给y的是两个Mindspore的Tensor对象,那么会直接返回两个Tensor对应位置相加的结果。但是如果输入给y的是两个普通python的list,则输出的结果会是两个list的拼接,这跟不同类型的加法的实现方式有关,在文末总结中会进行解释。这里我们只是想说明:y本身是一个元素加和的函数,可以通过vmap使其称为矢量加和的函数。关于输入的in_axes参数,指的是扩展的维度。比如我们写了一个支持(A,A)×(A,1)(A,A)×(A,1)维度的函数,如果把in_axes参数设置为0,那么就可以得到一个支持计算(B,A,A)×(B,A,1)(B,A,A)×(B,A,1)维度的函数。其中in_axes参数,决定的是被扩展的维度B所在的位置。这一点我们可以看一下vmap的官方示例:

    在这个案例中,也是定义了一个普通的加和函数,通过vmap去扩展不同的维度,大致的计算逻辑为:

    (A,)+(A,)+(A,)in_axes=(0,1,None)(B,A)+(A,B)+(A,)=(B,A)+(B,A)+(1,A)=(B,A)out_axes=1(A,B)
    (A,)+(A,)+(A,)in_axes=(0,1,None)(B,A)+(A,B)+(A,)=(B,A)+(B,A)+(1,A)=(B,A)out_axes=1(A,B)

    其实这个过程中关于in_axes是比较容易可以理解的,但是这个out_axes有时候会让人难以捉摸,在github上专门有人提出了这个issue并有人做出了解释:

    结合上面的案例,其实out_axes就是决定了扩展的维度B在结果中的位置,比如out_axes=1,所对应的结果中就是(x,B,x,...x)(x,B,x,...x)。也就是说,其不影响计算的结果,但是有可能会对计算结果进行转置操作,在MindSpore和Numpy中称为swap_axes

    总结概要

    本文介绍了华为推出的深度学习框架MindSpore中最新支持的vmap功能函数,可以用于向量化的计算,本质上的主要作用是替代并加速python中的for循环的操作。最早是在numba和pytroch、jax中对vmap功能进行了支持,其实numpy中的底层计算也用到了向量化的运算,因此速度才如此之快。vmap在python中更多的是与即时编译功能jit一同使用,能够起到简化编程的同时对性能进行极大程度的优化,尤其是python中的for循环的优化。但是对于一些numpy、jax或者MindSpore中已有的算子而言,还是建议直接使用其已经实现的算子,而不是vmap再手写一个。

    版权声明

    本文首发链接为:https://www.cnblogs.com/dechinphy/p/ms-vmap.html

    作者ID:DechinPhy

    更多原著文章请参考:https://www.cnblogs.com/dechinphy/

    打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

    腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

    参考链接

    1. https://gitee.com/mindspore/mindspore/blob/master/mindspore/python/mindspore/ops/functional.py#L845
  • 相关阅读:
    聊聊 Java 数据结构与算法中的堆最小堆和最大堆
    git创建仓库并建立远程连接
    STC8G1K08A-36I-SOP8 实验版 A1
    Enhancing Quality for HEVC Compressed Videos
    解锁区块链游戏数据解决方案
    Perfect matching
    深分页问题,mysql查询 limit 1000,10 和limit 10 一样快吗?
    2024年腾讯云4核8G云服务器性能测评和优惠价格表
    按钮倒计时提醒
    【iOS】—— 系统手势操作
  • 原文地址:https://www.cnblogs.com/dechinphy/p/ms-vmap.html