我们有N头牛,需要两两之间相互通讯,其中每头牛对应一个坐标x和一个听力v,设第i头牛的听力为v(i),坐标为x(i)(1<=x<=20000),
已知牛i和牛j相互通讯需要的音量为 max(v(i),v(j))*|x(i)-x(j)|,求出N(N-1)对通讯的音量的总和。
我们首先看下,牛i和牛j相互通讯需要的音量为 max(v(i),v(j))*|x(i)-x(j)|,那么如果我们对所有的牛根据v来排序,保证对于 i < j ,v[i] <= v[j],之后从第一个开始循环,计算 i 和 i 前面的所有牛的坐标差之和 * v[i]即可。
然后我们可以记录2个树状数组,其中一个bit用来存坐标,另一个用bitCnt来计数,每一次循环执行如下的事情
1)计算bit的前 x[i]项和leftSum,和 前 262144 的和allSum(所有元素)
2)计算bitCnt前 x[i]项和leftCnt,和 前 262144 的和allCnt(所有个数)
3)其中leftCnt就是牛i左边的牛的坐标和,例如3头牛 1 2 5,那么它左边的坐标和为1+2=3,然后leftCnt就是牛i左边牛的数量,即2,那么 i 和左边的坐标差=(5-1)+(5-1)=5 * 2 - (2 + 1)
所以左边的牛的音量 = (x[i] * leftCnt - leftSum )*v[i]
然后allCnt-leftCnt就是右边的数量rightCnt,allSum-leftSum就是右边的数量rightSum
同理右边的牛的音量 = ( rightSum - x[i]*rightCnt )*v[i]
之后把左右两边音量的和加到ans里即可
4)将bit的x[i]位+x[i],将bitCnt的x[i]位+1
bit[ x[ i ] ]+=x[i]
bitCnt[ x[ i ] ]+=1
同步更新两棵树的父节点...
(备注:本题目中坐标乘以数量然后不断求和的过程中 n * (n-1) * 20000可能会大于int32,注意给结果开long long)
- #include
- #include
- using namespace std;
- typedef long long ll;
- typedef pair<int, int> P;
- P num[262150];
- int bit[262150], n_, n, bitCnt[262150];
- ll ans = 0LL;
- void input()
- {
- scanf("%d", &n_);
- for (int i = 1; i <= n_; i++)
- {
- scanf("%d%d", &num[i].first, &num[i].second);
- }
- sort(num + 1, num + (1 + n_));
- }
- void init()
- {
- n = 262144;
- for (int i = 0; i <= n; i++)
- {
- bit[i] = 0;
- bitCnt[i] = 0;
- }
- }
- void update(int r, int v)
- {
- if (r <= 0)
- {
- return;
- }
- for (int i = r; i <= n; i = i + (i & (-i)))
- {
- bit[i] = bit[i] + v;
- }
- }
- void updateCnt(int r, int v)
- {
- if (r <= 0)
- {
- return;
- }
- for (int i = r; i <= n; i = i + (i & (-i)))
- {
- bitCnt[i] = bitCnt[i] + v;
- }
- }
- int query(int r)
- {
- int sum = 0;
- for (int i = r; i > 0; i = i - (i & (-i)))
- {
- sum = sum + bit[i];
- }
- return sum;
- }
- int queryCnt(int r)
- {
- int sum = 0;
- for (int i = r; i > 0; i = i - (i & (-i)))
- {
- sum = sum + bitCnt[i];
- }
- return sum;
- }
- void solve()
- {
- for (int i = 1; i <= n_; i++)
- {
- int leftSum = query(num[i].second);
- int allSum = query(n);
- int leftCnt = queryCnt(num[i].second);
- int allCnt = queryCnt(n);
- ans = ans + (((ll)((leftCnt * num[i].second) - leftSum)) * ((ll)num[i].first));
- ans = ans + (((ll)((allSum - leftSum) - ((allCnt - leftCnt) * num[i].second))) * ((ll)num[i].first));
- update(num[i].second, num[i].second);
- updateCnt(num[i].second, 1);
- }
- }
- int main()
- {
- input();
- init();
- solve();
- printf("%lld\n", ans);
- return 0;
- }