• 拒绝for循环,从take_along_axis开始


    技术背景

    在前一篇文章中,我们提到了关于Numpy中的各种取index的方法,可以用于取出数组里面的元素,也可以用于做切片,甚至可以用来做排序。但是遇到对于高维矩阵的某一个维度取多个值的时候,单纯的使用下标已经无法完成相关的操作了。如果找不到相应的接口,对于性能要求不高的场景可以使用一个for循环进行替代,但是对于性能要求比较高的场景下,我们还是尽可能的使用Numpy本身自带的接口,比如本文将要提到的take_along_axis操作。

    使用案例

    我们考虑这样的一个场景,给定一个维度为(4,11,3)的矩阵a作为数据,和一个维度为(4,2)的矩阵b作为下标,意味着从a中第二条轴的11个元素中每次取两个元素,也就是希望得到一个维度为(4,2,3)的结果:

    In [11]: a = np.arange(132).reshape((4,11,3))
    
    In [12]: a
    Out[12]: 
    array([[[  0,   1,   2],
            [  3,   4,   5],
            [  6,   7,   8],
            [  9,  10,  11],
            [ 12,  13,  14],
            [ 15,  16,  17],
            [ 18,  19,  20],
            [ 21,  22,  23],
            [ 24,  25,  26],
            [ 27,  28,  29],
            [ 30,  31,  32]],
    
           [[ 33,  34,  35],
            [ 36,  37,  38],
            [ 39,  40,  41],
            [ 42,  43,  44],
            [ 45,  46,  47],
            [ 48,  49,  50],
            [ 51,  52,  53],
            [ 54,  55,  56],
            [ 57,  58,  59],
            [ 60,  61,  62],
            [ 63,  64,  65]],
    
           [[ 66,  67,  68],
            [ 69,  70,  71],
            [ 72,  73,  74],
            [ 75,  76,  77],
            [ 78,  79,  80],
            [ 81,  82,  83],
            [ 84,  85,  86],
            [ 87,  88,  89],
            [ 90,  91,  92],
            [ 93,  94,  95],
            [ 96,  97,  98]],
    
           [[ 99, 100, 101],
            [102, 103, 104],
            [105, 106, 107],
            [108, 109, 110],
            [111, 112, 113],
            [114, 115, 116],
            [117, 118, 119],
            [120, 121, 122],
            [123, 124, 125],
            [126, 127, 128],
            [129, 130, 131]]])
    
    In [13]: b = np.array([[0,1],[1,2],[2,3],[3,4]])
    
    In [14]: b
    Out[14]: 
    array([[0, 1],
           [1, 2],
           [2, 3],
           [3, 4]])
    

    为了方便展示我们就定义了这样两个比较简单的矩阵a和b,那么在这个结果中,我们理想的结果应该是:

    [[[  0,   1,   2],
      [  3,   4,   5]],
    
     [[ 36,  37,  38],
      [ 39,  40,  41]],
    
     [[ 72,  73,  74],
      [ 75,  76,  77]],
    
     [[108, 109, 110],
      [111, 112, 113]]]
    

    这样的一个矩阵。关于这个结果的来源,可以对b这个定义进行展开解释,b的值为:

    [[0, 1],
     [1, 2],
     [2, 3],
     [3, 4]]
    

    它所表示的是在a[0]下取第0个元素和第1个元素,在a[1]下取第1个元素和第2个元素,以此类推。然而如果我们直接把定义好的b放到a的索引中或者直接使用numpy.take的方法的话,得到的结果是这样的:

    In [16]: a[:,b]
    Out[16]: 
    array([[[[  0,   1,   2],
             [  3,   4,   5]],
    
            [[  3,   4,   5],
             [  6,   7,   8]],
    
            [[  6,   7,   8],
             [  9,  10,  11]],
    
            [[  9,  10,  11],
             [ 12,  13,  14]]],
    
    
           [[[ 33,  34,  35],
             [ 36,  37,  38]],
    
            [[ 36,  37,  38],
             [ 39,  40,  41]],
    
            [[ 39,  40,  41],
             [ 42,  43,  44]],
    
            [[ 42,  43,  44],
             [ 45,  46,  47]]],
    
    
           [[[ 66,  67,  68],
             [ 69,  70,  71]],
    
            [[ 69,  70,  71],
             [ 72,  73,  74]],
    
            [[ 72,  73,  74],
             [ 75,  76,  77]],
    
            [[ 75,  76,  77],
             [ 78,  79,  80]]],
    
    
           [[[ 99, 100, 101],
             [102, 103, 104]],
    
            [[102, 103, 104],
             [105, 106, 107]],
    
            [[105, 106, 107],
             [108, 109, 110]],
    
            [[108, 109, 110],
             [111, 112, 113]]]])
    

    显然这不是我们想要的结果。需要额外申明的是,这个执行操作中,最后一个维度的冒号加与不加是一样的效果,跟numpy.take本质上也是同样的操作,因此就需要使用到numpy中的另外一个接口:take_along_axis,如下是其官方的API文档:

    还有相关的使用案例:

    需要注意的是,输入的indices必须要跟原始的数据矩阵保持同样的维度,因此在我们自己的案例中,对b进行了扩维,最终的代码如下所示:

    In [23]: np.take_along_axis(a,b[:,:,None],axis=1)
    Out[23]: 
    array([[[  0,   1,   2],
            [  3,   4,   5]],
    
           [[ 36,  37,  38],
            [ 39,  40,  41]],
    
           [[ 72,  73,  74],
            [ 75,  76,  77]],
    
           [[108, 109, 110],
            [111, 112, 113]]])
    

    最后得到的就是我们想要的结果了,并且是直接使用下标无法实现的操作(当然,也可能是我还没研究出来这样的操作)。这里axis设置为1,就表示a的第0个维度和b的第0个维度是一致的取法,也可以理解成全取的意思。

    总结概要

    Numpy是在Python中用于各种矩阵运算非常强大的工具之一,而快速的通过下标取出所需位置的元素也是numpy所支持的强大功能之一。常规的元素取法都可以通过numpy的下标或者是numpy.take函数来实现,比如array[0,:]可用于取第一条轴的所有元素,array[:,0]可以用于取第二条轴的所有第二个元素,放在一个2维的矩阵里面就分别是取第一行的所有元素和取第一列的所有元素。但是本文更加关注于更高维的矩阵,当我们想从多个维度中取多个元素时,是不太容易直接用下标去取的,比如同时取a[0][0],a[0][1],a[1][1],a[1][2]的话,那么就只能使用numpy所支持的另外一个函数numpy.take_along_axis来实现。

    版权声明

    本文首发链接为:https://www.cnblogs.com/dechinphy/p/take_along_axis.html

    作者ID:DechinPhy

    更多原著文章请参考:https://www.cnblogs.com/dechinphy/

    打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

    腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

    参考链接

    1. https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html#numpy.take_along_axis
  • 相关阅读:
    生产环境重大bug,update加上索引字段会走索引进行更新?还是走全表扫描
    Optional源码解析与实践
    李宏毅机器学习笔记:RNN循环神经网络
    瑞吉外卖实战项目全攻略——总结篇
    linux性能测试
    清理mac苹果电脑磁盘软件有哪些免费实用的?
    C++的奇妙之旅
    基于象鼻虫损害优化算法求解单目标无约束问题并可视化分析(Matlab代码实现)
    (Java)心得:LeetCode——16.最接近的三数之和
    学生家乡网页设计作品静态HTML网页—— HTML+CSS+JavaScript制作辽宁沈阳家乡主题网页源码(11页)
  • 原文地址:https://www.cnblogs.com/dechinphy/p/take_along_axis.html