在以流形式获取 n n n 个元素中,随机抽取 k k k 个样本,满足 k ≤ n k \le n k≤n(通常有 k ≪ n k \ll n k≪n)要求:
证明 使用数学归纳法。当 n = k n = k n=k 时,每个元素被选中的概率均为 1 1 1,结论显然成立。假设当 n = m n=m n=m( m > = k m >= k m>=k)时结论成立,只需证明当 n = m + 1 n = m + 1 n=m+1 时结论仍成立。
因为当
n
=
m
n = m
n=m 时结论成立,所以在遍历第
m
+
1
m+1
m+1 个元素前,前
m
m
m 个元素被选中的概率相同。因为前
m
m
m 个元素被选中的概率相同,所以每个元素被选中的概率为
P
m
(
i
)
=
k
÷
m
=
k
m
(
i
=
1
,
2
,
⋯
,
m
)
(1)
P_m(i) = k \div m = \frac{k}{m} \hspace{1em} (i = 1,2,\cdots,m) \tag{1}
Pm(i)=k÷m=mk(i=1,2,⋯,m)(1)
在遍历第
m
+
1
m+1
m+1 个元素时,有
k
m
+
1
\frac{k}{m+1}
m+1k 的概率会保留第
m
+
1
m+1
m+1 个元素,所以第
m
+
1
m+1
m+1 个元素被选中的概率为
P
m
+
1
(
m
+
1
)
=
k
m
+
1
(2)
P_{m+1}(m+1) = \frac{k}{m+1} \tag{2}
Pm+1(m+1)=m+1k(2)
在遍历第
m
+
1
m+1
m+1 个元素时,有
k
m
+
1
\frac{k}{m+1}
m+1k 的概率会在前
m
m
m 个元素选出的
k
k
k 个样本中随机丢掉一个,其中每个样本被丢掉的概率为
k
m
+
1
×
1
k
=
1
m
+
1
\frac{k}{m+1} \times \frac{1}{k} = \frac{1}{m+1}
m+1k×k1=m+11
所以,将式
(
1
)
(1)
(1) 代入可得,在遍历第
m
+
1
m+1
m+1 个元素后,前
m
m
m 个元素被选中的概率为
P
m
+
1
(
i
)
=
P
m
(
i
)
×
(
1
−
1
m
+
1
)
=
k
m
×
m
m
+
1
=
k
m
+
1
(
i
=
1
,
2
,
⋯
,
m
+
1
)
(3)
P_{m+1}(i) = P_m(i) \times (1 - \frac{1}{m+1}) = \frac{k}{m} \times \frac{m}{m+1} = \frac{k}{m+1} \hspace{1em} (i = 1,2,\cdots,m+1) \tag{3}
Pm+1(i)=Pm(i)×(1−m+11)=mk×m+1m=m+1k(i=1,2,⋯,m+1)(3)
综上所述,根据式
(
2
)
(2)
(2) 和式
(
3
)
(3)
(3) 可知,若当
n
=
m
n=m
n=m(
m
>
=
k
m >= k
m>=k)时结论成立,则当
n
=
m
+
1
n = m+1
n=m+1 时,前
m
m
m 个元素被选中的概率与第
m
+
1
m+1
m+1 个元素被选中的概率相同,均为
k
m
+
1
\frac{k}{m+1}
m+1k。得证。
根据算法步骤显然可知:
满足要求。
代码满足 PEP-0008、PEP-0484、pylint 规范;使用 numpy 注释文档规范。
import random
from typing import Any, Iterator, List
def reservoir_sampling(population: Iterator[Any], n_sample: int):
"""水塘抽样
Parameters
----------
population : Iterator[Any]
所有元素的可迭代对象(流状态)
n_sample : int
需要选取的样本数
Returns
-------
sample : List[Any]
选中的样本列表
"""
cache: List[Any] = [] # 缓存列表
i = 0
for elem in population:
if i < n_sample:
cache.append(elem)
elif random.random() < n_sample / i:
cache[random.randint(0, n_sample - 1)] = elem
i += 1
return cache
测试用例:
reservoir_sampling(range(5), 5) # [0, 1, 2, 3, 4]
reservoir_sampling(range(1000), 5) # [385, 921, 590, 492, 243]