• 字节跳动面试真题-寻找两个正序数组的中位数


    题目描述

    给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的中位数

    算法的时间复杂度应该为 O(log (m+n)) 。

    示例1:
    输入:nums1 = [1,3], nums2 = [2]
    输出:2.00000
    解释:合并数组 = [1,2,3] ,中位数 2

    示例2:
    输入:nums1 = [1,2], nums2 = [3,4]
    输出:2.50000
    解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5

    提示:

    nums1.length == m
    nums2.length == n
    0 <= m <= 1000
    0 <= n <= 1000
    1 <= m + n <= 2000
    -106 <= nums1[i], nums2[i] <= 106

    题解

    暴力法

    思路

    这是一道典型的数组类题目,可以先利用归并排序将两个有序数组合并。然后根据奇数还是偶数来返回中位数。

    上述思路需要合并数组,会带来内存空间的开销,其实我们并不需要真的合并数组,而是只要找到中位数在哪里就可以了。由于两个数组长度已知,因此中位数的下表也是已知的,我们只需要维护两个指针,分别指向两个数组下标为0的位置,通过移动指针,找到对应的中位数即可。

    上面两种思路简单易懂,但是时间复杂度并不满足要求,他们都需要O(m+n)的时间复杂度。但是题目要求我们实现时间复杂度log(m+n),通过读题目我们找到关键词:正序数组 。因此我们可以考虑二分查找的方式。如果对时间复杂度的要求有 log,通常都需要用到二分查找,这道题也可以通过二分查找实现。

    根据中位数的定义,当 m+n 是奇数时,中位数是两个有序数组中的第 (m+n)/2个元素,当 m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2个元素和第 (m+n)/2+1个元素的平均值。因此,这道题可以转化成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。

    假设两个数组分别是A和B。要找到第k个元素,我们可以比较 A[k/2 - 1]和 B[k/2 - 1],其中 / 表示整数除法。我们可以归纳出三种情况:

    • 如果A[k/2 - 1] < B[k/2 - 1],则比A[k/2 − 1] 小的数最多只有 A 的前 k/2-1 个数和 B 的前 k/2-1 个数,即比 A[k/2−1] 小的数最多只有 k-2 个,因此 A[k/2−1] 不可能是第 k 个数,A[0] 到 A[k/2−1] 也都不可能是第 k 个数,可以全部排除。

    • 如果 A[k/2−1]>B[k/2−1],则可以排除 B[0] 到 B[k/2−1]。

    • 如果 A[k/2−1]=B[k/2−1],则可以归入第一种情况处理。

    我们可以看到,比较 A[k/2−1] 和 B[k/2−1] 之后,可以排除 k/2 个不可能是第 k 小的数,查找范围缩小了一半。同时,我们将在排除后的新数组上继续进行二分查找,并且根据我们排除数的个数,减少 k 的值,这是因为我们排除的数都不大于第 k 小的数。

    有以下三种情况需要特殊处理:

    如果 A[k/2−1] 或者 B[k/2−1] 越界,那么我们可以选取对应数组中的最后一个元素。在这种情况下,我们必须根据排除数的个数减少 k 的值,而不能直接将 k 减去 k/2。

    如果一个数组为空,说明该数组中的所有元素都被排除,我们可以直接返回另一个数组中第 k 小的元素。

    如果 k=1,我们只要返回两个数组首元素的最小值即可。

    我们举个例子来说明上述算法思路,假设有两个有序数组如下所示:

    A: 1 3 4 9
    B: 1 2 3 4 5 6 7 8 9

    两个有序数组的长度分别是 4 和 9,长度之和是 13,中位数是两个有序数组中的第 7 个元素,因此需要找到第 k [k=7] 个元素。

    比较两个有序数组中下标为 k/2-1=2 的数,即A[2] 和 B[2],如下面所示:

    A: 1 3 4 9
    B: 1 2 3 4 5 6 7 8 9

    由于 A[2]>B[2],因此排除 B[0] 到 B[2],即数组 B 的下标偏移(offset)变为 3,同时更新 k 的值:k=k-k/2=4。

    下一步寻找,比较两个有序数组中下标为 k/2-1=1 的数,即 A[1] 和 B[4],如下面所示,其中方括号部分表示已经被排除的数。

    A: 1 3 4 9
    B: [1 2 3] 4 5 6 7 8 9

    由于 A[2]=B[3],根据之前的规则,排除 A 中的元素,因此排除 A[2],即数组 A 的下标偏移变为 3,同时更新 k 的值: k=k-k/2=1。

    由于 k 的值变成 1,因此比较两个有序数组中的未排除下标范围内的第一个数,其中较小的数即为第 k 个数,由于 A[3]>B[3],因此第 k 个数是 B[3]=4。

    A: [1 3 4] 9
    B: [1 2 3] 4 5 6 7 8 9

    代码

    Go

    func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
        totalLength := len(nums1) + len(nums2)
        if totalLength%2 == 1 {
            midIndex := totalLength/2
            return float64(getKthElement(nums1, nums2, midIndex + 1))
        } else {
            midIndex1, midIndex2 := totalLength/2 - 1, totalLength/2
            return float64(getKthElement(nums1, nums2, midIndex1 + 1) + getKthElement(nums1, nums2, midIndex2 + 1)) / 2.0
        }
        return 0
    }
    
    func getKthElement(nums1, nums2 []int, k int) int {
        index1, index2 := 0, 0
        for {
            if index1 == len(nums1) {
                return nums2[index2 + k - 1]
            }
            if index2 == len(nums2) {
                return nums1[index1 + k - 1]
            }
            if k == 1 {
                return min(nums1[index1], nums2[index2])
            }
            half := k/2
            newIndex1 := min(index1 + half, len(nums1)) - 1
            newIndex2 := min(index2 + half, len(nums2)) - 1
            pivot1, pivot2 := nums1[newIndex1], nums2[newIndex2]
            if pivot1 <= pivot2 {
                k -= (newIndex1 - index1 + 1)
                index1 = newIndex1 + 1
            } else {
                k -= (newIndex2 - index2 + 1)
                index2 = newIndex2 + 1
            }
        }
        return 0
    }
    
    func min(x, y int) int {
        if x < y {
            return x
        }
        return y
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46

    Python

    class Solution:
        def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
            def getKthElement(k):
                """
                - 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较
                - 这里的 "/" 表示整除
                - nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个
                - nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个
                - 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个
                - 这样 pivot 本身最大也只能是第 k-1 小的元素
                - 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组
                - 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组
                - 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数
                """
                
                index1, index2 = 0, 0
                while True:
                    # 特殊情况
                    if index1 == m:
                        return nums2[index2 + k - 1]
                    if index2 == n:
                        return nums1[index1 + k - 1]
                    if k == 1:
                        return min(nums1[index1], nums2[index2])
    
                    # 正常情况
                    newIndex1 = min(index1 + k // 2 - 1, m - 1)
                    newIndex2 = min(index2 + k // 2 - 1, n - 1)
                    pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
                    if pivot1 <= pivot2:
                        k -= newIndex1 - index1 + 1
                        index1 = newIndex1 + 1
                    else:
                        k -= newIndex2 - index2 + 1
                        index2 = newIndex2 + 1
            
            m, n = len(nums1), len(nums2)
            totalLength = m + n
            if totalLength % 2 == 1:
                return getKthElement((totalLength + 1) // 2)
            else:
                return (getKthElement(totalLength // 2) + getKthElement(totalLength // 2 + 1)) / 2
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43

    复杂度分析

    • 时间复杂度:O(log(m+n)),其中 m 和 n 分别是数组 nums_1
      和 nums_2的长度。初始时有 k=(m+n)/2 或 k=(m+n)/2+1,每一轮循环可以将查找范围减少一半,因此时间复杂度是 O(log(m+n))。

    • 空间复杂度 :O(1)

  • 相关阅读:
    java源码系列:HashMap源码验证,在JDK8中新增红黑树详解
    Windows下CMD操作常用指令详解
    NDIS小端口驱动开发(三)
    GO编译时避免引入外部动态库的解决方法
    [附源码]Python计算机毕业设计Django现代诗歌交流平台
    灌区流量监测设备:农田灌溉的“智慧眼”
    艾美捷ProSci丨ProSci I kappa B 激酶检测套装解决方案
    【网络篇】第十一篇——简单的TCP英译汉服务器
    jdk1.8.191 JVM内存参数 InitialRAMPercentage和MinRAMPercentage
    Celery笔记四之在Django中使用celery
  • 原文地址:https://blog.csdn.net/u010665216/article/details/125543042