• np.partition介绍


    前言

    这次分享一个 numpy 里面的一个高级函数partition,这个函数在一些搜索、匹配、找相关性的时候会用到。功能强大,但是一般人不知道、不会用,或者不知道怎么用。

    这次就分享一下具体的用法,也是numpy技巧第二篇文章

    同时代码也都是开源的,链接为:https://github.com/yuanzhoulvpi2017/tiny_python/blob/main/numpy_base,文件编号是02开通的。

    介绍

    np.partition是对一个向量、沿着一个维度方向、按照大小对数据进行分堆,分成了两堆。

    1. kth前面的这堆数值,都是这个向量里面比较小的群体们。
    2. kth后面的这堆数值,都是这个向量里面比较大的群体们。
    一个简单的例子:
    data = np.array([232, 564, 278, 3, 2, 1, -1, -10, -30, -40])
    np.partition(data, kth=4)
    
    • 1
    • 2

    np.partition(data, kth=4)的意思就是:

    1. np.partitiondata说:“你们班,现在给我从左向右站好!我不需要你们完全从低到高排序好,我只要左边 4 个是你们里面最小的就行”
    2. 然后data就找到班里最小的 4 个同学:-1, -10, -30, -40,说:“你们几个赶快给我站到左边,不需要你们几个再排序了,怎么快怎么来”
    3. -1, -10, -30, -40听到data指令后,马上跑到左边站好。
    4. 剩下的没有被指出来的,想怎么站都无所谓
    与此类似的
    1. np.partition功能和np.argpartition功能是一样的,只不过np.argpartition返回的是序号
    np.argpartition(data, kth=4)
    
    • 1

    np.argpartition的意思就是:

    1. np.argpartitiondata说:“把你们班最差的几个人学号放在前 4 个坑,剩下人随便填上就行了。
    2. data-1, -10, -30, -40说,你们几个差学生也别出来了,就对我说你们学号多少,我来填上,然后-1说我是 6 号,-10说我说 7 号,-30说我是 8 号,-40说我是 9 号。
    3. 然后剩下的人的编号,data敷衍了事,随便写上了。
    解决问题

    那么我们要想找到data的最小的 4 个数字,其实非常简单。两个方法:

    # way 1
    np.partition(data, kth=4)[:4]
    
    # way 2
    data[np.argpartition(data, kth=4)[:4]]
    
    • 1
    • 2
    • 3
    • 4
    • 5

    推广扩展

    那么问题来了,我想找到data最大的 3 位数怎么办?

    1. 一般第 k 个,我们在 python 里面都是使用 k 为正数,也就是从左向右数第 k 个。
    2. 在 python 取后 k 个,其实我们都是知道的,就是使用-k这个方法。

    那么这个方法其实在这里也是适用的,下面就是解决方法。就不过多做解释了。(注意返回的结果的后三个数值)

    np.partition(data, kth=-3)# 返回的是具体的值
    
    np.argpartition(data, kth=-3) # 返回的是值对应的序号
    
    • 1
    • 2
    • 3

    我要是取后 3 个序号,其实就是 top3 的值了:

    # way1
    np.partition(data, kth=-3)[-3:]
    
    # way 2
    data[np.argpartition(data, kth=-3)[-3:]]
    
    • 1
    • 2
    • 3
    • 4
    • 5

    更高维度怎么办

    上面的data只是一维的,对于二维及更高维度的数据同样适用。
    这里要注意一个小细节:

    1. 假设一个数组的 shape 是: (m,n,z)
      那么axis=1的方向其实就是沿着第二个也就是n这个方向。希望可以帮助读者分清楚.

    实际问题

    假如你是一个研究疾病的研究生,手上有个数据:

    1. 有一系列数据:其中有 425 个疾病名称,有 13426 个症状
    2. 还有一个疾病名称症状权重matrix,矩阵的 shape 为425 x 13426

    需要解决的问题是:
    需要按照权重matrix找到每个疾病名称前 10 个最相关的症状,并且记录下来。

    解析
    1. 这里需要处理的数组变成了二维数组,找 top10(不需要排序,只要找到),并且记录下来。
    2. 这里使用np.argpartition可以一次性将所有的 topk 找出来,大大的提高了计算效率
    这里分享代码
    
    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    
    # generate sample data
    n_features = 13426
    n_disease = 425
    features = [f"feature_{i}" for i in range(n_features)]
    disease = [f"disease_{i}" for i in range(n_disease)]
    weights = np.random.random((n_disease, n_features))
    
    
    #function
    
    def getdata(top_k: int) -> pd.DataFrame:
        index = np.argpartition(weights, -top_k, axis=1)[:, -top_k:]
    
        def slice_data(i):
            temp_data = pd.DataFrame({
                'features': np.array(features)[index[i, :]]})
            temp_data['disease'] = disease[i]
            temp_data['weights'] = weights[i, index[i, :]]
            return temp_data
    
        res = pd.concat([slice_data(i) for i in tqdm(range(weights.shape[0]))]).reset_index(drop=True)
        return res
    
    
    final_data = getdata(top_k=3) # 这里只是找top3的,要是想找top10的,修改数值就行了
    final_data.shape
    final_data.head(4)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    结果如下

  • 相关阅读:
    基于微信药店药品商城小程序系统设计与实现 开题报告
    用HTTP服务的方式集成 learned cardinality estimate 方法进 Postgresql
    【JVM】垃圾回收机制中,对象进入老年代的触发条件
    vue中的数据依赖如何追踪收集
    GPU驱动及CUDA安装流程介绍
    GO语言篇之文件操作
    python开发工具PyCharm使用教程:安装
    SPDK LVOL +Blobstore +FIO 使用
    XSS高级 svg 复现一个循环问题以及两个循环问题
    【LeetCode】11. 盛最多水的容器
  • 原文地址:https://blog.csdn.net/yuanzhoulvpi/article/details/126902156