这次分享一个 numpy 里面的一个高级函数partition
,这个函数在一些搜索、匹配、找相关性的时候会用到。功能强大,但是一般人不知道、不会用,或者不知道怎么用。
这次就分享一下具体的用法,也是numpy技巧第二篇文章
。
同时代码也都是开源的,链接为:https://github.com/yuanzhoulvpi2017/tiny_python/blob/main/numpy_base,文件编号是02
开通的。
np.partition
是对一个向量、沿着一个维度方向、按照大小对数据进行分堆,分成了两堆。
kth
前面的这堆数值,都是这个向量里面比较小的群体们。kth
后面的这堆数值,都是这个向量里面比较大的群体们。data = np.array([232, 564, 278, 3, 2, 1, -1, -10, -30, -40])
np.partition(data, kth=4)
np.partition(data, kth=4)
的意思就是:
np.partition
对data
说:“你们班,现在给我从左向右站好!我不需要你们完全从低到高排序好,我只要左边 4 个是你们里面最小的就行”data
就找到班里最小的 4 个同学:-1
, -10
, -30
, -40
,说:“你们几个赶快给我站到左边,不需要你们几个再排序了,怎么快怎么来”-1
, -10
, -30
, -40
听到data
指令后,马上跑到左边站好。np.partition
功能和np.argpartition
功能是一样的,只不过np.argpartition
返回的是序号np.argpartition(data, kth=4)
np.argpartition
的意思就是:
np.argpartition
对data
说:“把你们班最差的几个人学号放在前 4 个坑,剩下人随便填上就行了。data
对-1
, -10
, -30
, -40
说,你们几个差学生也别出来了,就对我说你们学号多少,我来填上,然后-1
说我是 6 号,-10
说我说 7 号,-30
说我是 8 号,-40
说我是 9 号。data
敷衍了事,随便写上了。那么我们要想找到data
的最小的 4 个数字,其实非常简单。两个方法:
# way 1
np.partition(data, kth=4)[:4]
# way 2
data[np.argpartition(data, kth=4)[:4]]
那么问题来了,我想找到data
最大的 3 位数怎么办?
-k
这个方法。那么这个方法其实在这里也是适用的,下面就是解决方法。就不过多做解释了。(注意返回的结果的后三个数值)
np.partition(data, kth=-3)# 返回的是具体的值
np.argpartition(data, kth=-3) # 返回的是值对应的序号
我要是取后 3 个序号,其实就是 top3 的值了:
# way1
np.partition(data, kth=-3)[-3:]
# way 2
data[np.argpartition(data, kth=-3)[-3:]]
上面的data
只是一维的,对于二维及更高维度的数据同样适用。
这里要注意一个小细节:
axis=1
的方向其实就是沿着第二个也就是n
这个方向。希望可以帮助读者分清楚.假如你是一个研究疾病的研究生,手上有个数据:
疾病名称
,有 13426 个症状
疾病名称
和症状
的权重matrix
,矩阵的 shape 为425 x 13426
需要解决的问题是:
需要按照权重matrix
找到每个疾病名称
前 10 个最相关的症状
,并且记录下来。
np.argpartition
可以一次性将所有的 topk 找出来,大大的提高了计算效率
import numpy as np
import pandas as pd
from tqdm import tqdm
# generate sample data
n_features = 13426
n_disease = 425
features = [f"feature_{i}" for i in range(n_features)]
disease = [f"disease_{i}" for i in range(n_disease)]
weights = np.random.random((n_disease, n_features))
#function
def getdata(top_k: int) -> pd.DataFrame:
index = np.argpartition(weights, -top_k, axis=1)[:, -top_k:]
def slice_data(i):
temp_data = pd.DataFrame({
'features': np.array(features)[index[i, :]]})
temp_data['disease'] = disease[i]
temp_data['weights'] = weights[i, index[i, :]]
return temp_data
res = pd.concat([slice_data(i) for i in tqdm(range(weights.shape[0]))]).reset_index(drop=True)
return res
final_data = getdata(top_k=3) # 这里只是找top3的,要是想找top10的,修改数值就行了
final_data.shape
final_data.head(4)