一个正整数 x 的二进制表示为 ,其中等于1的位是
则 x 可以被二进制表示为
不妨设 ,进一步的,区间[1, x] 可以分成 O(logx) 个小区间
这些小区间的共同特点是:若区间结尾为R,则区间长度就是等于R的“二进制分解”下最小的2的次幂,即 lowbit(R).
例如:,区间 [1, 7] 可以分成 [1, 4] [5, 6] [7, 7]
长度分别是 lowbit(4) = 4, lowbit(6) = 2, lowbit(7) = 1
〔manim | 算法 | 数据结构〕 完全理解并深入应用树状数组 | 支持多种动态维护区间操作_哔哩哔哩_bilibili
树状数组(Binary Indexed Tree)是一种 基于上述思想的数据结构,其基本用途就是维护序列的前缀和。对于给定的序列 a ,我们建立一个数组 c, 其中 c[x] 保存 a 的区间
[x - lowbit(x) + 1, x] 中所有数的和,
黑色数组代表原来的数组(下面用A[i]代替),红色结构代表我们的树状数组(下面用C[i]代替),发现没有,每个位置只有一个方框,令每个位置存的就是子节点的值的和,则有
- C[1] = A[1];
- C[2] = A[1] + A[2];
- C[3] = A[3];
- C[4] = A[1] + A[2] + A[3] + A[4];
- C[5] = A[5];
- C[6] = A[5] + A[6];
- C[7] = A[7];
- C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8];
可以发现,这颗树是有规律的
C[i] = A[i - 2k+1] + A[i - 2k+2] + ... + A[i]; //k为i的二进制中从最低位到高位连续零的长度
例如i = 8(1000)时候,k = 3,可自行验证。
这个怎么实现求和呢,比如我们要找前7项和,那么应该是SUM = C[7] + C[6] + C[4];
而根据上面的式子,容易得出
其实树状数组就是一个二进制上面的应用。
输入样例:
5 1 5 3 2 4输出样例:
3 4
- #include
- #include
- #include
- #include
-
- using namespace std;
-
- const int N = 200010;
-
- typedef long long LL;
-
- int n;
- int a[N];
- int tr[N];
- int greaterr[N], lower[N];
-
- //返回非负整数x在二进制表示下最低位1及其后面的0构成的数值
- int lowbit(int x)
- {
- return x & -x;
- }
-
- //将序列中第x个数加上k。
- void add(int x, int c)
- {
- for(int i = x; i <= n; i += lowbit(i)) tr[i] += c;
- }
-
- //查询序列前x个数的和
- int sum(int x)
- {
- int res = 0;
- for(int i = x; i; i -= lowbit(i)) res += tr[i];
- return res;
- }
-
- int main()
- {
- cin >> n;
-
- for(int i = 1; i <= n; i ++ ) scanf("%d",&a[i]);
-
- //从左向右,依次统计每个位置左边比第i个数y小的数的个数、以及大的数的个数
- for(int i = 1; i <= n; i ++ )
- {
- int y = a[i];
-
- //在前面已加入树状数组的所有数中统计在区间[1, y - 1]的数字的出现次数
- greaterr[i] = sum(n) - sum(y);
-
- //在前面已加入树状数组的所有数中统计在区间[y + 1, n]的数字的出现次数
- lower[i] = sum(y - 1);
-
- //将y加入树状数组,即数字y出现1次
- add(y,1);
- }
-
- //清空树状数组,从右往左统计每个位置右边比第i个数y小的数的个数、以及大的数的个数
- memset(tr, 0, sizeof tr);
-
- LL res1 = 0, res2 = 0;
- for(int i = n; i; i --)
- {
- int y = a[i];
- res1 += greaterr[i] * (LL)(sum(n) - sum(y));
- res2 += lower[i] * (LL)(sum(y - 1));
-
- //将y加入树状数组,即数字y出现1次
- add(y,1);
- }
-
- cout << res1 << " " << res2;
- return 0;
- }
输入样例:
10 5 1 2 3 4 5 6 7 8 9 10 Q 4 Q 1 Q 2 C 1 6 3 Q 2输出样例:
4 1 2 5
树状数组 + 差分
#include #include #include #include using namespace std; typedef long long LL; const int N = 100010; int n, m; int a[N]; LL tr[N]; int lowbit(int x) { return x & -x; } void add(int x, int t) { for(int i = x; i <= n; i += lowbit(i)) tr[i] += t; } LL sum(int x) { LL res = 0; for(int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } int main() { cin >> n >> m; for(int i = 1; i <= n; i ++ ) scanf("%d", &a[i]); for(int i = 1; i <= n; i ++ ) add(i, a[i] - a[i - 1]); while(m -- ) { char op[2]; int l, r, d; scanf("%s%d", op, &l); if(*op == 'C') { scanf("%d%d", &r, &d); add(l, d), add(r + 1, -d); } else { printf("%lld\n", sum(l)); } } return 0; }
输入样例:
10 5 1 2 3 4 5 6 7 8 9 10 Q 4 4 Q 1 10 Q 2 4 C 3 6 3 Q 2 4输出样例:
4 55 9 15
- #include
- #include
- #include
- #include
-
- using namespace std;
-
- typedef long long LL;
-
- const int N = 100010;
-
- int n,m;
- int a[N];
- LL tr1[N]; //维护差分数组b[i]的前缀和
- LL tr2[N]; //维护b[i] * i 的前缀和
-
- int lowbit(int x)
- {
- return x & -x;
- }
-
- void add(LL tr[], int x, LL c)
- {
- for(int i = x; i <= n; i += lowbit(i)) tr[i] += c;
- }
-
- LL sum(LL tr[], int x)
- {
- LL res = 0;
- for(int i = x; i; i -= lowbit(i)) res += tr[i];
- return res;
- }
-
- LL prefix_sum(int x)
- {
- return sum(tr1, x) * (x + 1) - sum(tr2, x);
- }
-
- int main()
- {
- cin >> n >> m;
- for(int i = 1; i <= n; i ++ ) scanf("%d",&a[i]);
- for(int i = 1; i <= n; i ++ )
- {
- int b = a[i] - a[i - 1];
- add(tr1, i, b);
- add(tr2, i, (LL)i * b);
- }
-
- while(m --)
- {
- char op[2];
- int l,r,d;
- scanf("%s%d%d", op, &l, &r);
- if(*op == 'Q')
- {
- printf("%lld\n",prefix_sum(r) - prefix_sum(l - 1));
- }
- else
- {
- scanf("%d",&d);
- add(tr1, l, d), add(tr1, r + 1, -d);
- add(tr2, l, l * d), add(tr2, r + 1, (r + 1) * -d);
- }
- }
- return 0;
- }
输入样例:
5 1 2 1 0输出样例:
2 4 5 3 1
- // 找到一个最小的x是sum(x) = k
- #include
- #include
- #include
-
- using namespace std;
-
- const int N = 100010;
-
- int n;
- int h[N];
- int ans[N];
- int tr[N];
-
- int lowbit(int x)
- {
- return x & -x;
- }
-
- void add(int x, int c)
- {
- for(int i = x; i <= n; i += lowbit(i)) tr[i] += c;
- }
-
- int sum(int x)
- {
- int res = 0;
- for(int i = x; i; i -= lowbit(i)) res += tr[i];
- return res;
- }
-
- int main()
- {
- cin >> n;
- for(int i = 2; i <= n; i ++ ) scanf("%d", &h[i]);
-
- for(int i = 1; i <= n; i ++ ) tr[i] = lowbit(i);
-
- for(int i = n; i; i -- )
- {
- int k = h[i] + 1;
- int l = 1, r = n;
- while(l < r)
- {
- int mid = l + r >> 1;
- if(sum(mid) >= k) r = mid;
- else l = mid + 1;
- }
- ans[i] = r;
- add(r, -1);
- }
-
- for(int i = 1; i <= n; i ++) printf("%d\n", ans[i]);
-
- return 0;
- }