一棵树有 n 个节点,编号为 1~n,每个节点都有一个权值 w 。完成以下操作。
① CHANGE u t ,把节点 u 的权值修改为 t
② QMAX u v ,询问从节点 u 到节点 v 路径上节点的最大权值
③ QSUM u v ,询问从节点 u 到节点 v 路径上节点的权值和
注意:从节点 u 到节点 v 路径上的节点包括 u 和 v 自身。
第 1 行包含一个整数 n ,表示节点的个数。接下来的 n -1 行,每行都包含两个整数 a 和 b ,表示在节点 a 和节点 b 之间有一条边相连。接下来的 n 行,第 i 行的整数 wi 表示节点 i 的权值。接下来的一行包含一个整数 q ,表示操作的总数。最后有 q 行,每行都表示一种操作,操作形式如上所述 。 其中,1≤n≤30000,0≤q≤200000,保证操作中每个节点的权值 w 都为 -30000~ 30000。
对每个 QMAX 或者 QSUM 的操作,都单行输出一个整数,表示要求的结果。
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
4
1
2
2
10
6
5
6
5
16
本问题包括树上点更新、区间最值、区间和值查询。可以用树链剖分将树形结构线性化,然后用线段树进行点更新、区间最值、区间和值查询。解决方案:树链剖分+线段树。
1 第 1 次深度优先遍历求 dep、fa、size、son,第 2 次深度优先遍历求 top、id、rev;
2 创建线段树;
3 点更新,u 对应的下标 i = id[u],在线段树中将该下标的值更新为 val;
4 区间查询,求 u 、v 之间的最值与和值。若 u 、v 不在同一条重链上,则一边查询,一边向同一条重链靠拢;若 u 、v 在同一条重链上,则根据节点的下标在线段树中进行区间查询。
- package com.platform.modules.alg.alglib.hysbz1036;
-
- public class Hysbz1036 {
- public String output = "";
- private int maxn = 30005;
- int n, m; // n 个结点,m 个查询
- int head[] = new int[maxn]; // 头结点
- int cnt = 0, total = 0;
- int fa[] = new int[maxn]; // 父亲
- int dep[] = new int[maxn]; // 深度
- int size[] = new int[maxn]; // 子树结点总数
- int son[] = new int[maxn]; // 重儿子
- int top[] = new int[maxn]; // 所在重链顶端结点
- int w[] = new int[maxn]; // 权值
- int id[] = new int[maxn];
- int rev[] = new int[maxn]; // u 对应的 dfs 序下标,下标对于的 u
- int Max, Sum;
- edge e[] = new edge[maxn << 1];
- // 树结点存储数组
- node tree[] = new node[maxn * 4];
-
- void add(int u, int v) {
- e[++cnt].to = v;
- e[cnt].next = head[u];
- head[u] = cnt;
- }
-
- // 求 dep,fa,size,son
- void dfs1(int u, int f) {
- size[u] = 1;
- for (int i = head[u]; i > 0; i = e[i].next) {
- int v = e[i].to;
- if (v == f) // 父节点
- continue;
- dep[v] = dep[u] + 1; // 深度
- fa[v] = u;
- dfs1(v, u);
- size[u] += size[v];
- if (size[v] > size[son[u]])
- son[u] = v;
- }
- }
-
- // 求 top,id,rev
- void dfs2(int u, int t) {
- top[u] = t;
- id[u] = ++total; // u 对应的 dfs 序下标
- rev[total] = u; // dfs 序下标对应的结点 u
- if (son[u] == 0)
- return;
- dfs2(son[u], t); // 沿着重儿子 dfs
- for (int i = head[u]; i > 0; i = e[i].next) {
- int v = e[i].to;
- if (v != fa[u] && v != son[u])
- dfs2(v, v);
- }
- }
-
- // 创建线段树,k 表示存储下标,区间 [l,r]
- void build(int k, int l, int r) {
- tree[k].l = l;
- tree[k].r = r;
- if (l == r) {
- tree[k].mx = tree[k].sum = w[rev[l]];
- return;
- }
- int mid, lc, rc;
- mid = (l + r) / 2; // 划分点
- lc = k * 2; // 左孩子存储下标
- rc = k * 2 + 1; // 右孩子存储下标
- build(lc, l, mid);
- build(rc, mid + 1, r);
- tree[k].mx = Math.max(tree[lc].mx, tree[rc].mx); // 结点的最大值等于左右孩子最值的最大值
- tree[k].sum = tree[lc].sum + tree[rc].sum; // 结点的和值等于左右子树和值
- }
-
- // 求区间[l..r]的最值、和值
- void query(int k, int l, int r) {
- if (tree[k].l >= l && tree[k].r <= r) { // 找到该区间
- Max = Math.max(Max, tree[k].mx);
- Sum += tree[k].sum;
- return;
- }
- int mid, lc, rc;
- mid = (tree[k].l + tree[k].r) / 2;//划分点
- lc = k * 2; //左孩子存储下标
- rc = k * 2 + 1;//右孩子存储下标
- if (l <= mid)
- query(lc, l, r);//到左子树查询
- if (r > mid)
- query(rc, l, r);//到右子树查询
- }
-
- void ask(int u, int v) {//求u,v之间的最值或和值
- while (top[u] != top[v]) {//不在同一条重链上
- if (dep[top[u]] < dep[top[v]]) {
- int temp = u;
- u = v;
- v = temp;
- }
- query(1, id[top[u]], id[u]); // u 顶端结点和 u之间
- u = fa[top[u]];
- }
- if (dep[u] > dep[v]) { // 在同一条重链上
- int temp = u;
- u = v;
- v = temp;
- }
- query(1, id[u], id[v]);
- }
-
- void update(int k, int i, int val) {//u对应的下标i,将其值修改更新为val
- if (tree[k].l == tree[k].r && tree[k].l == i) {//找到i
- tree[k].mx = tree[k].sum = val;
- return;
- }
- int mid, lc, rc;
- mid = (tree[k].l + tree[k].r) / 2;//划分点
- lc = k * 2; //左孩子存储下标
- rc = k * 2 + 1;//右孩子存储下标
- if (i <= mid)
- update(lc, i, val);//到左子树修改更新
- else
- update(rc, i, val);//到右子树修改更新
- tree[k].mx = Math.max(tree[lc].mx, tree[rc].mx);//返回时修改更新最值
- tree[k].sum = tree[lc].sum + tree[rc].sum;//返回时修改更新和值
- }
-
- public Hysbz1036() {
- for (int i = 0; i < e.length; i++) {
- e[i] = new edge();
- }
- for (int i = 0; i < tree.length; i++) {
- tree[i] = new node();
- }
- }
-
- public String cal(String input) {
- int x, y;
- String str;
-
- String[] line = input.split("\n");
- String[] words = line[0].split(" ");
- n = Integer.parseInt(words[0]);
- for (int i = 1; i < n; i++) {
- String[] num = line[i].split(" ");
- x = Integer.parseInt(num[0]);
- y = Integer.parseInt(num[1]);
- add(x, y);
- add(y, x);
- }
- String[] wage = line[n].split(" ");
- for (int i = 1; i <= n; i++)
- w[i] = Integer.parseInt(wage[i-1]);
- dep[1] = 1;
- dfs1(1, 0);
- dfs2(1, 1);
- build(1, 1, total);//创建线段树
- m = Integer.parseInt(line[n + 1]);
- for (int i = n + 2; i <= m + n + 1; i++) {
- String[] query = line[i].split(" ");
- str = query[0];
- x = Integer.parseInt(query[1]);
- y = Integer.parseInt(query[2]);
- if (str.charAt(0) == 'C')
- update(1, id[x], y);
- else {
- Sum = 0;
- Max = -0x3f3f3f3f;
- ask(x, y);
- if (str.charAt(1) == 'M')
- output += Max + "\n";
- else
- output += Sum + "\n";
- }
- }
- return output;
- }
- }
-
- class edge {
- int to, next;
- }
-
- // 结点
- class node {
- int l, r, mx, sum; // l,r 表示区间左右端点,mx 表示区间 [l,r] 的最值
- }
