• 水塘抽样(应用场景+算法步骤+算法证明+Python实现)


    应用场景

    在以流形式获取 n n n 个元素中,随机抽取 k k k 个样本,满足 k ≤ n k \le n kn(通常有 k ≪ n k \ll n kn)要求:

    • 每个元素有相同的概率被选中;
    • 时间复杂度不超过 O ( n ) O(n) O(n),即遍历一次;空间复杂度不超过 O ( k ) O(k) O(k),即要抽取的样本数。
    算法步骤
    • 初始化长度为 k k k 的缓存列表 L L L
    • 遍历第 i i i 个元素:
      • 如果 i ≤ k i \le k ik,则将该元素存入缓存列表;
      • 如果 i > k i>k i>k,则在 [ 0 , 1 ] [0,1] [0,1] 中取随机数 t t t,若 t < k i t < \frac{k}{i} t<ik,则保留第 i i i 个元素(保留概率为 k i \frac{k}{i} ik),并在缓存列表 L L L 中随机丢掉一个元素(每个元素被丢掉的概率为 1 i \frac{1}{i} i1);
    • 遍历完所有元素后,缓存列表 L L L 即为抽取的样本。
    算法证明
    1. 每个元素有相同的概率被选中

    证明 使用数学归纳法。当 n = k n = k n=k 时,每个元素被选中的概率均为 1 1 1,结论显然成立。假设当 n = m n=m n=m m > = k m >= k m>=k)时结论成立,只需证明当 n = m + 1 n = m + 1 n=m+1 时结论仍成立。

    因为当 n = m n = m n=m 时结论成立,所以在遍历第 m + 1 m+1 m+1 个元素前,前 m m m 个元素被选中的概率相同。因为前 m m m 个元素被选中的概率相同,所以每个元素被选中的概率为
    P m ( i ) = k ÷ m = k m ( i = 1 , 2 , ⋯   , m ) (1) P_m(i) = k \div m = \frac{k}{m} \hspace{1em} (i = 1,2,\cdots,m) \tag{1} Pm(i)=k÷m=mk(i=1,2,,m)(1)
    在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会保留第 m + 1 m+1 m+1 个元素,所以第 m + 1 m+1 m+1 个元素被选中的概率为
    P m + 1 ( m + 1 ) = k m + 1 (2) P_{m+1}(m+1) = \frac{k}{m+1} \tag{2} Pm+1(m+1)=m+1k(2)
    在遍历第 m + 1 m+1 m+1 个元素时,有 k m + 1 \frac{k}{m+1} m+1k 的概率会在前 m m m 个元素选出的 k k k 个样本中随机丢掉一个,其中每个样本被丢掉的概率为
    k m + 1 × 1 k = 1 m + 1 \frac{k}{m+1} \times \frac{1}{k} = \frac{1}{m+1} m+1k×k1=m+11
    所以,将式 ( 1 ) (1) (1) 代入可得,在遍历第 m + 1 m+1 m+1 个元素后,前 m m m 个元素被选中的概率为
    P m + 1 ( i ) = P m ( i ) × ( 1 − 1 m + 1 ) = k m × m m + 1 = k m + 1 ( i = 1 , 2 , ⋯   , m + 1 ) (3) P_{m+1}(i) = P_m(i) \times (1 - \frac{1}{m+1}) = \frac{k}{m} \times \frac{m}{m+1} = \frac{k}{m+1} \hspace{1em} (i = 1,2,\cdots,m+1) \tag{3} Pm+1(i)=Pm(i)×(1m+11)=mk×m+1m=m+1k(i=1,2,,m+1)(3)
    综上所述,根据式 ( 2 ) (2) (2) 和式 ( 3 ) (3) (3) 可知,若当 n = m n=m n=m m > = k m >= k m>=k)时结论成立,则当 n = m + 1 n = m+1 n=m+1 时,前 m m m 个元素被选中的概率与第 m + 1 m+1 m+1 个元素被选中的概率相同,均为 k m + 1 \frac{k}{m+1} m+1k。得证。

    2. 时间复杂度不超过 O ( n ) O(n) O(n),空间复杂度不超过 O ( k ) O(k) O(k)

    根据算法步骤显然可知:

    • 时间复杂度为 O ( n ) O(n) O(n);遍历了所有元素;
    • 空间复杂度为 O ( k ) O(k) O(k),为缓存列表 L L L 的长度为 k k k

    满足要求。

    代码实现

    代码满足 PEP-0008、PEP-0484、pylint 规范;使用 numpy 注释文档规范。

    import random
    from typing import Any, Iterator, List
    
    
    def reservoir_sampling(population: Iterator[Any], n_sample: int):
        """水塘抽样
    
        Parameters
        ----------
        population : Iterator[Any]
            所有元素的可迭代对象(流状态)
        n_sample : int
            需要选取的样本数
    
        Returns
        -------
        sample : List[Any]
            选中的样本列表
        """
        cache: List[Any] = []  # 缓存列表
        i = 0
        for elem in population:
            if i < n_sample:
                cache.append(elem)
            elif random.random() < n_sample / i:
                cache[random.randint(0, n_sample - 1)] = elem
            i += 1
        return cache
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28

    测试用例

    reservoir_sampling(range(5), 5)  # [0, 1, 2, 3, 4]
    reservoir_sampling(range(1000), 5)  # [385, 921, 590, 492, 243]
    
    • 1
    • 2
  • 相关阅读:
    mysql的行锁和间隙锁
    .NET Emit 入门教程:第七部分:实战项目1:将 DbDataReader 转实体
    简单工厂模式
    挂载硬盘相关操作-linux004
    【scikit-learn基础】--『预处理』之 数据缩放
    【每日一题Day358】LC2698求一个整数的惩罚数 | 递归
    python 基础语法学习 (二)
    Nautlius Chain主网正式上线,模块Layer3时代正式开启
    Himall商城字符串帮助类获得指定顺序的字符在字符串中的位置索引
    HTTP协议
  • 原文地址:https://blog.csdn.net/Changxing_J/article/details/128185299