• 【TopK问题】基于堆的方法&基于分治策略的方法


    说明

    1. TopK问题:对于给定的数组,选出其中最大/最小的k个元素,或是选出第k大/第k小元素;
    2. 本文整理了两种实现方法,分别是
      • 基于堆的实现方法:和堆排序有所不同的是,仅仅通过构建含有k个元素的堆,最终得到最大/最小的k个元素
      • 基于分治策略的方法:采用了快速排序的思想,对原数组进行划分,但和快排不同的是,每次仅处理划分后的一边
    3. 文章内容为个人学习整理,如有错误,欢迎指正。

    1. 基于堆的方法

    1.1 算法步骤

    问题要求得到最大的k个元素,就可以构建含有k个元素的小根堆(相应地,若是求最小的k个元素,就构建大根堆)。

    1. 首先利用原数组的前k个元素构建小根堆;
    2. 从原数组的第k+1个元素开始向后遍历,并依次比较元素与堆顶元素大小,若大于堆顶元素则替换堆顶元素并及时调整堆;否则继续向后遍历;
    3. 当数组遍历完毕后,小根堆中存储的k个元素就是原数组中最大的k个元素。

    1.2 算法实现

    LeetCode相关题目:215. 数组中的第K个最大元素

    //使用数组的前k个元素构造含有k个元素的小根堆
        //从k+1开始遍历,每次和堆顶元素比较,若被遍历到的元素大于堆顶元素,则替换堆顶元素并调整堆,保证堆内的k个元素总是当前最大的k个元素。
        int findKthLargest(vector<int>& nums, int k) {
            vector<int> heap_k(nums.begin(), nums.begin()+k); //选取nums中的前k个元素     
            BuildMinHeap(heap_k); //将这k个元素建成小根堆
            
            for(int i=k; i<nums.size(); i++){//从第k+1个元素(下标为k)开始依次和堆顶元素比较
                if(nums[i] > heap_k[0]){
                    heap_k[0] = nums[i]; 
                    MinHeapAdjust(heap_k, 0, k);//若被遍历到的元素比堆顶元素大,则替换堆顶元素并调整堆
                }
            }               
            return heap_k[0];//heap_k是小根堆,heap_k[0]中就是原数组第k大元素
        }
    	//构建小根堆
        void BuildMinHeap(vector<int>& nums){
            int n = nums.size();
            for(int i=n/2; i>=0; i--){//从第一个非叶结点开始调整
                MinHeapAdjust(nums, i, n);
            }
        }
        //调整小根堆
        void MinHeapAdjust(vector<int>& nums, int i, int n){
            int temp = nums[i]; //暂存被筛选的结点
            for(int j=i*2+1; j<n; j=i*2+1){//j初始时指向i结点的左孩子
                if(j+1<n && nums[j+1]<nums[j]) j++;//调整j,使其指向i的左右孩子中的较小值
    
                if(temp <= nums[j]) break;//若当前被筛选结点temp更小,说明自此结点向下度符合小根堆的要求,可以提前终止筛选
                else{
                    nums[i] = nums[j];//否则将孩子结点中的更小者调整到双亲位置上
                    i = j; //更新i指针以便继续向下筛选
                }
            }
            nums[i] = temp; //被筛选结点放在其最终位置
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    2. 基于分治策略的方法

    2.1 算法步骤

    在快速排序中,最主要的步骤是pivotPos = Partition(nums, left, right),它利用数组中的一个元素作为pivot,将下标从left到right的元素分为两部分,并以pivot为枢轴将比pivot小的元素放在左边,比pivot大的元素放在右边,通过不断划分数组,最终获得整体的排序。

    在TopK问题中也利用了这种不断划分的分治策略,但是在快排中,每次要处理左右两部分,在TopK问题中对这一步骤做了简化,即每次只处理一边。因为要找的是最大/最小的k个元素,因此可以通过比较pivotPosk的大小来判断下一次要处理的是左边还是右边。

    2.2 算法实现

    P.S. 关于RANDOMIZED-SELECT的算法对应《算法导论(第3版)》9.2期望为线性时间的选择算法,关于SELECT的算法对应9.3最坏情况为线性时间的选择算法

    • 基于随机选择的算法实现
    //方法4:分治法,随机选择 
        int findKthLargest(vector<int>& nums, int k) {
            randomizedSelect(nums, 0, nums.size()-1, k);
            return nums[k-1];
        }
    
        //划分:随机选择RANDOMIZED-SELECT
        int Partition(vector<int>& nums, int left, int right){
            int pivotPos = rand()%(right-left+1) + left;//生成[left,right]范围内的随机数
            int pivot = nums[pivotPos];//随机选择元素作为枢轴
    
            swap(nums[left], nums[pivotPos]);//将枢轴元素和最左元素交换,之后将最左元素作为枢轴(算法书中称为主元)
    
            //获得降序序列
            while(left<right){
                while(left<right && nums[right]<=pivot) right--;//因为将最左元素作为枢轴,因此也要先移动右侧指针
                nums[left] = nums[right];
                while(left<right && nums[left]>=pivot) left++;
                nums[right] = nums[left];
            }
            pivotPos = left;//最终左右指针相遇,该位置即为pivot的最终位置
            nums[pivotPos] = pivot;
            return pivotPos;
        }
        //随机选择递归函数
        void randomizedSelect(vector<int>& nums, int left, int right, int k){
            if(left >= right) return;//递归返回条件
    
            int pivotPos = Partition(nums, left, right);
            if(pivotPos == k) return; //找到kth
            else if(pivotPos > k) randomizedSelect(nums, left, pivotPos-1,k);//按降序排列,因此当pivotPos比k大时,说明要找的kth在序列的左侧
            else randomizedSelect(nums, pivotPos+1, right, k);
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 最坏情况为线性时间的选择算法 SELECT,以下内容来自《算法导论》
    1. 将输入数组的n个元素划分为n/5组,每组5个元素,且至多只有一组由剩下的n mod 5个元素组成;
    2. 寻找n/5组中每一组的中位数:首先对每组元素进行插入排序,然后确定每组有序元素的中位数;
    3. 对第二步找到的n/5个中位数,递归调用SELECT函数找出其中位数x(若有偶数个中位数,约定x是较小的中位数);
    4. 按中位数的中位数x对数组进行划分,让k比划分的低区中的元素数目多1,因此x是第k小元素,且有n-k个元素在划分的高区;
    5. 若i=k,则返回x;若i.k,则在高区递归查找第i-k小的元素。

    参考学习:
    线性时间选择问题
    BFPRT——Top k问题的终极解法

    选出第k大的元素:

    //找到第k大的元素
    
    //划分函数
    int Partition(vector<int>& nums, int left, int right, int pivot){//pivot是传入的中位数的中位数
        for(int index=left; index<=right; index++){//在left,right范围内寻找pivot的下标
            if(nums[index] == pivot){
                swap(nums[left], nums[index]);//和最左元素交换作为主元
                break;
            }
        }
        //降序序列
        int i=left, j=right;
        while(i<j){
            while(i<j && nums[j]<=pivot) j--;
            while(i<j && nums[i]>=pivot) i++;
            swap(nums[i], nums[j]);
        }
        swap(nums[left], nums[i]);
        return i;
    }
    
    //对[begin,end]范围内数据进行排序,并返回中位数下标
    int indexOfMedian(vector<int>& nums, int begin, int end){
        sort(nums.begin()+begin, nums.begin()+end+1, greater<int>());
        int index = begin + (end - begin)/2;
        return index;
    }
    
    int select(vector<int>& nums, int left, int right, int kth){
        if(right-left+1 <= 5){
            //元素个数在5个以内则直接排序并返回此次的kth
            sort(nums.begin()+left, nums.begin()+right+1, greater<int>());
            return nums[left + kth -1];//注意下标要减一
        }
    
        int count = right - left + 1;
        int groups = count/5 + (count%5 > 0 ? 1 : 0);//总共有多少组
        for(int i=0; i<groups; i++){//i是组号,从0开始,按组遍历
            int index = indexOfMedian(nums, left+i*5, min(left+i*5+4, right));//这里要对最后一组进行处理,最后一组元素个数可能不足5个
            swap(nums[left+i], nums[index]);//将中位数换到数组的前面,方便下一次取数
        }
    
        int pivot = select(nums, left, left+groups-1, groups/2);//中位数的中位数,若个数为偶数,选择较小中位数
        int pivotPos = Partition(nums, left, right, pivot);//按照pivot进行划分
    
        int num_left = pivotPos - left + 1;//[left, pivotPos]之间一共多少个元素
        if(num_left == kth) return nums[pivotPos];//若下标为pivotPos的元素恰好为kth个元素,返回
        else if(num_left > kth) return select(nums, left, pivotPos-1, kth);//kth在左半区
        else return select(nums, pivotPos+1, right, kth-num_left);//kth在右半区,kth-num_left是kth在右半区的相对位置    
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50

    P.S.对SELECT算法自己理解得还不够透彻,以后继续把这一块儿补充一下。

  • 相关阅读:
    MySQL 查询条件
    idea:JavaWeb(maven)Servlet 03
    INTERSPEECH 2022——基于层级上下文语义信息的多尺度语音合成风格建模
    NDP 协议介绍
    Kotlin 泛型
    Go 字符串操作实战
    13个小众有趣的网站,只有程序员才看得懂
    R语言 数据的整理与清洗(Data Frame 篇上)
    TS代码整洁之道(下)
    规则引擎基础知识
  • 原文地址:https://blog.csdn.net/qq_45800517/article/details/134236074