给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的中位数。
算法的时间复杂度应该为 O(log (m+n)) 。
示例1:
输入:nums1 = [1,3], nums2 = [2]
输出:2.00000
解释:合并数组 = [1,2,3] ,中位数 2
示例2:
输入:nums1 = [1,2], nums2 = [3,4]
输出:2.50000
解释:合并数组 = [1,2,3,4] ,中位数 (2 + 3) / 2 = 2.5
提示:
nums1.length == m
nums2.length == n
0 <= m <= 1000
0 <= n <= 1000
1 <= m + n <= 2000
-106 <= nums1[i], nums2[i] <= 106
这是一道典型的数组类题目,可以先利用归并排序将两个有序数组合并。然后根据奇数还是偶数来返回中位数。
上述思路需要合并数组,会带来内存空间的开销,其实我们并不需要真的合并数组,而是只要找到中位数在哪里就可以了。由于两个数组长度已知,因此中位数的下表也是已知的,我们只需要维护两个指针,分别指向两个数组下标为0的位置,通过移动指针,找到对应的中位数即可。
上面两种思路简单易懂,但是时间复杂度并不满足要求,他们都需要O(m+n)的时间复杂度。但是题目要求我们实现时间复杂度log(m+n),通过读题目我们找到关键词:正序数组 。因此我们可以考虑二分查找的方式。如果对时间复杂度的要求有 log,通常都需要用到二分查找,这道题也可以通过二分查找实现。
根据中位数的定义,当 m+n 是奇数时,中位数是两个有序数组中的第 (m+n)/2个元素,当 m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2个元素和第 (m+n)/2+1个元素的平均值。因此,这道题可以转化成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。
假设两个数组分别是A和B。要找到第k个元素,我们可以比较 A[k/2 - 1]和 B[k/2 - 1],其中 / 表示整数除法。我们可以归纳出三种情况:
如果A[k/2 - 1] < B[k/2 - 1],则比A[k/2 − 1] 小的数最多只有 A 的前 k/2-1 个数和 B 的前 k/2-1 个数,即比 A[k/2−1] 小的数最多只有 k-2 个,因此 A[k/2−1] 不可能是第 k 个数,A[0] 到 A[k/2−1] 也都不可能是第 k 个数,可以全部排除。
如果 A[k/2−1]>B[k/2−1],则可以排除 B[0] 到 B[k/2−1]。
如果 A[k/2−1]=B[k/2−1],则可以归入第一种情况处理。
我们可以看到,比较 A[k/2−1] 和 B[k/2−1] 之后,可以排除 k/2 个不可能是第 k 小的数,查找范围缩小了一半。同时,我们将在排除后的新数组上继续进行二分查找,并且根据我们排除数的个数,减少 k 的值,这是因为我们排除的数都不大于第 k 小的数。
有以下三种情况需要特殊处理:
如果 A[k/2−1] 或者 B[k/2−1] 越界,那么我们可以选取对应数组中的最后一个元素。在这种情况下,我们必须根据排除数的个数减少 k 的值,而不能直接将 k 减去 k/2。
如果一个数组为空,说明该数组中的所有元素都被排除,我们可以直接返回另一个数组中第 k 小的元素。
如果 k=1,我们只要返回两个数组首元素的最小值即可。
我们举个例子来说明上述算法思路,假设有两个有序数组如下所示:
A: 1 3 4 9
B: 1 2 3 4 5 6 7 8 9
两个有序数组的长度分别是 4 和 9,长度之和是 13,中位数是两个有序数组中的第 7 个元素,因此需要找到第 k [k=7] 个元素。
比较两个有序数组中下标为 k/2-1=2 的数,即A[2] 和 B[2],如下面所示:
A: 1 3 4 9
B: 1 2 3 4 5 6 7 8 9
由于 A[2]>B[2],因此排除 B[0] 到 B[2],即数组 B 的下标偏移(offset)变为 3,同时更新 k 的值:k=k-k/2=4。
下一步寻找,比较两个有序数组中下标为 k/2-1=1 的数,即 A[1] 和 B[4],如下面所示,其中方括号部分表示已经被排除的数。
A: 1 3 4 9
B: [1 2 3] 4 5 6 7 8 9
由于 A[2]=B[3],根据之前的规则,排除 A 中的元素,因此排除 A[2],即数组 A 的下标偏移变为 3,同时更新 k 的值: k=k-k/2=1。
由于 k 的值变成 1,因此比较两个有序数组中的未排除下标范围内的第一个数,其中较小的数即为第 k 个数,由于 A[3]>B[3],因此第 k 个数是 B[3]=4。
A: [1 3 4] 9
B: [1 2 3] 4 5 6 7 8 9
func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
totalLength := len(nums1) + len(nums2)
if totalLength%2 == 1 {
midIndex := totalLength/2
return float64(getKthElement(nums1, nums2, midIndex + 1))
} else {
midIndex1, midIndex2 := totalLength/2 - 1, totalLength/2
return float64(getKthElement(nums1, nums2, midIndex1 + 1) + getKthElement(nums1, nums2, midIndex2 + 1)) / 2.0
}
return 0
}
func getKthElement(nums1, nums2 []int, k int) int {
index1, index2 := 0, 0
for {
if index1 == len(nums1) {
return nums2[index2 + k - 1]
}
if index2 == len(nums2) {
return nums1[index1 + k - 1]
}
if k == 1 {
return min(nums1[index1], nums2[index2])
}
half := k/2
newIndex1 := min(index1 + half, len(nums1)) - 1
newIndex2 := min(index2 + half, len(nums2)) - 1
pivot1, pivot2 := nums1[newIndex1], nums2[newIndex2]
if pivot1 <= pivot2 {
k -= (newIndex1 - index1 + 1)
index1 = newIndex1 + 1
} else {
k -= (newIndex2 - index2 + 1)
index2 = newIndex2 + 1
}
}
return 0
}
func min(x, y int) int {
if x < y {
return x
}
return y
}
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
def getKthElement(k):
"""
- 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较
- 这里的 "/" 表示整除
- nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个
- nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个
- 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个
- 这样 pivot 本身最大也只能是第 k-1 小的元素
- 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组
- 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组
- 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数
"""
index1, index2 = 0, 0
while True:
# 特殊情况
if index1 == m:
return nums2[index2 + k - 1]
if index2 == n:
return nums1[index1 + k - 1]
if k == 1:
return min(nums1[index1], nums2[index2])
# 正常情况
newIndex1 = min(index1 + k // 2 - 1, m - 1)
newIndex2 = min(index2 + k // 2 - 1, n - 1)
pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
if pivot1 <= pivot2:
k -= newIndex1 - index1 + 1
index1 = newIndex1 + 1
else:
k -= newIndex2 - index2 + 1
index2 = newIndex2 + 1
m, n = len(nums1), len(nums2)
totalLength = m + n
if totalLength % 2 == 1:
return getKthElement((totalLength + 1) // 2)
else:
return (getKthElement(totalLength // 2) + getKthElement(totalLength // 2 + 1)) / 2
时间复杂度:O(log(m+n)),其中 m 和 n 分别是数组 nums_1
和 nums_2的长度。初始时有 k=(m+n)/2 或 k=(m+n)/2+1,每一轮循环可以将查找范围减少一半,因此时间复杂度是 O(log(m+n))。
空间复杂度 :O(1)