给定三个整数数组
A = [ A 1 , A 2 , … A N ] , A=[A_1,A_2,…A_N], A=[A1,A2,…AN],
B = [ B 1 , B 2 , … B N ] , B=[B_1,B_2,…B_N], B=[B1,B2,…BN],
C = [ C 1 , C 2 , … C N ] , C=[C_1,C_2,…C_N], C=[C1,C2,…CN],
请你统计有多少个三元组 ( i , j , k ) (i,j,k) (i,j,k) 满足:
第一行包含一个整数 N N N。
第二行包含 N N N 个整数 A 1 , A 2 , … A N 。 A_1,A_2,…A_N。 A1,A2,…AN。
第三行包含 N N N 个整数 B 1 , B 2 , … B N 。 B_1,B_2,…B_N。 B1,B2,…BN。
第四行包含 N N N 个整数 C 1 , C 2 , … C N 。 C_1,C_2,…C_N。 C1,C2,…CN。
一个整数表示答案。
1 ≤ N ≤ 105 , 1≤N≤105, 1≤N≤105,
0 ≤ A i , B i , C i ≤ 105 0≤Ai,Bi,Ci≤105 0≤Ai,Bi,Ci≤105
- 3
- 1 1 1
- 2 2 2
- 3 3 3
27
我自己想的时候是枚举 a a a 组中的每一个元素,然后二分出 b b b 和 c c c 中的边界点,然后相乘。但是自己想错了, 因为 b b b 和 c c c 有关系,所以不能直接相乘。
那么其实可以枚举 b b b 中的每一个点,再根据 b i b_i bi 找到 a , c a,c a,c 中小于 b i bi bi 的数 和 大于 b i bi bi 的数 ,因为如此分, a a a 和 c c c 就是独立存在,可以用乘法原理
如果暴力解决,那么必然是需要 O ( n 3 ) O(n^3) O(n3)的复杂度枚举三个点再判断其大小,必定超时
所以可以想办法优化,可以看出我们找小于和大于 b i b_i bi 的数时,可以
对三个数组排序,枚举 b b b 的每一个点,然后二分出 a a a 中小于 b i b_i bi 的第一个数和 b b b 中大于 b i b_i bi 的第一个数,算出个数相乘即可
O ( n l o g n ) O(nlogn) O(nlogn)
- #include <iostream>
- #include <algorithm>
- using namespace std;
-
- typedef long long LL;
-
- const int N = 100010;
- int q[3][N];
- int n;
-
- int lowbound(int i,int x) {
- int l = 0, r = n - 1;
- while(l < r) {
- int mid = l + r + 1>> 1;
- if(q[i][mid] < x) l = mid;
- else r = mid - 1;
- }
-
- if(q[i][l] >= x) l = -1;
-
- return l;
- }
-
- int upbound(int i,int x) {
- int l = 0, r = n - 1;
- while(l < r) {
- int mid = l + r >> 1;
- if(q[i][mid] > x) r = mid;
- else l = mid + 1;
- }
-
- if(q[i][r] <= x) r = n;
-
- return r;
- }
-
- int main()
- {
- cin >> n;
- for(int i = 0;i < 3;i ++) {
- for(int j = 0;j < n;j ++) scanf("%d",&q[i][j]);
- sort(q[i],q[i] + n);
- }
-
-
- LL res = 0;
- for(int i = 0;i < n;i ++) {
- int x = q[1][i];
- int l1 = lowbound(0,x);
- int l2 = upbound(2,x);
- res += (LL)(l1 + 1) * (n - l2);
- }
- cout << res << endl;
- return 0;
- }
O ( N ) O(N) O(N)
- #include <cstring>
- #include <iostream>
- using namespace std;
-
- typedef long long LL;
-
- const int N = 100010;
- int a[N], b[N], c[N];
- int as[N]; // as[i]表示小于b[i]的数的个数
- int cs[N]; // cs[i]表示大于b[i]的数的个数
- int cnt[N], s[N];
- int n;
-
- int main()
- {
- scanf("%d",&n);
-
- for(int i = 0;i < n;i ++) scanf("%d",&a[i]), a[i] ++;
- for(int i = 0;i < n;i ++) scanf("%d",&b[i]), b[i] ++;
- for(int i = 0;i < n;i ++) scanf("%d",&c[i]), c[i] ++;
-
- for(int i = 0;i < n;i ++) cnt[a[i]] ++;
- for(int i = 1;i < N;i ++) s[i] = s[i - 1] + cnt[i];
-
- // 求as
- for(int i = 0;i < n;i ++) as[i] = s[b[i] - 1];
-
- memset(cnt, 0, sizeof cnt);
- memset(s, 0, sizeof s);
- // 求cs
- for(int i = 0;i < n;i ++) cnt[c[i]] ++;
- for(int i = 1;i < N;i ++) s[i] = s[i - 1] + cnt[i];
- for(int i = 0;i < n;i ++) cs[i] = s[N - 1] - s[b[i]];
-
- LL res = 0;
- for(int i = 0;i < n;i ++) res += (LL) as[i] * cs[i];
-
- cout << res << endl;
- return 0;
- }