[@Toc](leetcode 6103 — 从树中删除边的最小分数)

如果我们删除两条边,那么可以分为两种情况讨论:
1、四个节点在从根节点到某个叶节点的一条路径上;

2、四个节点分别在从根节点到两个叶节点的路径上。

我们不妨设两条边的第一条边的节点为 min1,max1(
m
i
n
1
深
度
<
m
a
x
1
深
度
min1深度 < max1深度
min1深度<max1深度);第二条边的节点为 min2,max2(
m
i
n
2
深
度
<
m
a
x
2
深
度
min2深度 < max2深度
min2深度<max2深度)。
针对第一种情况,假设 m a x 1 深 度 < m a x 2 深 度 max1深度 < max2深度 max1深度<max2深度,不难发现三个连通域分别为 max2 的子树,max1 的子树中去除 max2 的子树以及剩余部分。
针对第二种情况,三个连通域分别 max2 的子树,max1 的子树以及剩余部分。
那么由于求的是异或,只要我们知道任意两个连通域的异或结果和总的异或结果就可以知道第三个连通域的异或。而上述异或可以通过所有节点的子树异或得到,即一次 dfs。
这道题困扰我的就是如何判断要删除的两条边属于情况1还是情况2。其实这个问题类似寻找两个节点的公共祖先节点。但是如果针对每种边的组合都进行一次查找,复杂度会比较高。看了网上的解答,我学习到了可以借助时间戳的概念来实现。
针对每个节点,记录一个 in,out 数组,分别用于记录访问该节点的开始时间和结束时间。在 dfs 的过程中,每当访问某个节点时,将当前时间赋值给 in 数组并递增当前时间;每当结束某个节点的 dfs 时,将当前时间赋值给 out 数组。
注意这里的 in,out 数组同样可以用来判断同一分支上的相对深度。
class Solution {
public:
int minimumScore(vector<int>& nums, vector<vector<int>>& edges) {
int n = nums.size(), e = edges.size();
vector<vector<int>> edgesWithNode(n);
for (int i = 0; i < e; ++i) {
edgesWithNode[edges[i][0]].push_back(edges[i][1]);
edgesWithNode[edges[i][1]].push_back(edges[i][0]);
}
vector<int> xorsChildren(n, -1);
vector<bool> visited(n);
vector<int> in(n), out(n);
int clock = 0;
function<void(int)> dfs = [&](int index) {
visited[index] = true;
in[index] = clock++;
xorsChildren[index] = nums[index];
for (auto next : edgesWithNode[index]) {
if (!visited[next]) {
dfs(next);
xorsChildren[index] = xorsChildren[next] ^ xorsChildren[index];
}
}
out[index] = clock;
};
dfs(0);
int total = xorsChildren[0];
int ret = INT_MAX;
for (int i = 0; i < e; ++i) {
for (int j = i + 1; j < e; ++j) {
vector<int> curXors(3);
int max1 = in[edges[i][0]] > in[edges[i][1]] ? edges[i][0] : edges[i][1];
int max2 = in[edges[j][0]] > in[edges[j][1]] ? edges[j][0] : edges[j][1];
int min1 = max1 == edges[i][0] ? edges[i][1] : edges[i][0];
int min2 = max2 == edges[j][0] ? edges[j][1] : edges[j][0];
int low = in[max1] > in[max2] ? max1 : max2;
int high = low == max1 ? max2 : max1;
bool sameLine = in[high] <= in[low] && out[low] <= out[high];
if (!sameLine) {
curXors[0] = xorsChildren[max1];
curXors[1] = xorsChildren[max2];
}
else {
if (in[max1] >= in[max2]) {
curXors[0] = xorsChildren[max1];
curXors[1] = xorsChildren[max2];
curXors[1] ^= curXors[0];
}
else {
curXors[0] = xorsChildren[max1];
curXors[1] = xorsChildren[max2];
curXors[0] ^= curXors[1];
}
}
curXors[2] = total ^ curXors[0] ^ curXors[1];
int score = max({ curXors[0], curXors[1], curXors[2] }) - min({ curXors[0], curXors[1], curXors[2] });
ret = min(ret, score);
}
}
return ret;
}
};
