题目:
O((m+n)/2)解法,合并数组,不过还能优化(当一个数组已经为空时,可以直接计算中位数)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
| #include <vector> using std::vector;
class Solution { public: double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) { int m = nums1.size(), n = nums2.size(); int last_num;
int mi = 0, ni = 0; int index = 0; int total = m + n; double res = 0;
while (mi < m || ni < n) { if (mi < m && ni < n) { if (nums1[mi] < nums2[ni]) { last_num = nums1[mi++];
} else { last_num = nums2[ni++]; } } else if (mi >= m && ni < n) { last_num = nums2[ni++];
} else if (mi < m && ni >= n) { last_num = nums1[mi++]; }
if (total % 2 == 0) { if(index == total / 2 - 1){ res += last_num; }else if(index == total / 2){ res += last_num; res = (double) res/2; break; } } else { if (index == total / 2) { res = last_num; break; } } index++; } return res; } };
|
二分法:
根据中位数的定义:
- 当
m+n 是奇数时,中位数是两个有序数组中的第 (m+n+1)/2 个元素
- 当
m+n 是偶数时,中位数是两个有序数组中的第 (m+n)/2 个元素和第 (m+n)/2+1 个元素的平均值。
因此,这道题可以转化成寻找两个有序数组中的第 k 小的数,其中 k 为 (m+n)/2 或 (m+n)/2+1。
核心思想就是:每次删除 k/2 个元素(每次排除 k/2 个不可能的元素)
设:pivot1 = nums1[k/2-1] ,pivot2 = nums2[k/2-1], 比较:pivot1 vs pivot2 , 初始 index1 和 index2 均为0
情况1 : pivot1 <= pivot2
说明:nums1[0...k/2-1] 都不可能是第k小,因为nums1 <= pivot1 的最多 k/2 个, nums2 <= pivot2 的最多 k/2-1 个, 总共 <= k-1
因此:nums1前 k/2 个全部删除
更新:index1 += k/2, k -= k/2
情况2:pivot2 < pivot1
- 同理:
nums2 前 k/2 个删除
- 更新:
index2 += k/2, k -= k/2
此外有三种边界情况,循环开始先处理边界:
- 一个数组已经空了
例如:nums1 = [], nums2 = [1,2,3,4], 第 k 小就是:nums2[k-1]
1 2 3 4
| if (index1 == m) return nums2[index2 + k - 1]; if (index2 == n) return nums1[index1 + k - 1];
|
k == 1
第1小 = 两数组最小值
1
| return min(nums1[index1], nums2[index2]);
|
- 正常二分删除
不断减少 k 直到:k == 1
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
| #include <vector> using std::vector;
class Solution { public: int getKthElement(const vector<int> &nums1, const vector<int> &nums2, int k) {
int m = nums1.size(); int n = nums2.size(); int idx1 = 0, idx2 = 0; int new_idx1, new_idx2, pivot1, pivot2;
while (true) { if (idx1 == m) return nums2[idx2 + k - 1];
if (idx2 == n) return nums1[idx1 + k - 1];
if (k == 1) return std::min(nums1[idx1], nums2[idx2]);
new_idx1 = std::min(idx1 + k / 2 - 1, m - 1); new_idx2 = std::min(idx2 + k / 2 - 1, n - 1); pivot1 = nums1[new_idx1]; pivot2 = nums2[new_idx2];
if (pivot1 <= pivot2) { k -= new_idx1 - idx1 + 1; idx1 = new_idx1 + 1; } else { k -= new_idx2 - idx2 + 1; idx2 = new_idx2 + 1; } } }
double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) { int total_len = nums1.size() + nums2.size(); int k, val1, val2;
if (total_len % 2 == 1) { k = (total_len + 1) / 2; return getKthElement(nums1, nums2, k); } else { k = total_len / 2; val1 = getKthElement(nums1, nums2, k);
k = total_len / 2 + 1; val2 = getKthElement(nums1, nums2, k);
return (val1 + val2) / 2.0; } } };
|
hot 100 rewrite: 没写出来哈哈哈,这个二分有点难想到