Tree DP的典型例题,值得一写。
树状DP的核心思想是,将每个节点及其构成的子树看做一个整体,求整体子树的状态值。此时父亲节点子树
T
r
e
e
(
u
)
Tree(u)
Tree(u)的状态值可以通过DFS由子树
T
r
e
e
(
v
)
Tree(v)
Tree(v)的状态值推出,而当需要求子树节点的答案时,可以将子树的根节点看做为整体树的根节点,通过简单的递推计算得到新的状态值。
此题需要求解所有距离之和,因此隔离出每棵子树,设该子树的根节点
v
v
v到子树中每个节点的距离之和为
d
p
(
v
)
dp(v)
dp(v)。经过计算可得,子树1中的
d
p
(
1
)
=
3
dp(1)=3
dp(1)=3,而1234 四个点到达0的距离之和需要对每个点加上0-1之间的边距离1,实质就是加上子树中节点的数量
s
i
z
e
(
v
)
size(v)
size(v)。即子树
T
r
e
e
(
1
)
Tree(1)
Tree(1)对于树
T
r
e
e
(
0
)
Tree(0)
Tree(0)的贡献为3+4。同理,子树
T
r
e
e
(
5
)
Tree(5)
Tree(5)对于树
T
r
e
e
(
0
)
Tree(0)
Tree(0)的贡献为0+1。最后将所有子树的dp值累计即可得到根节点的dp值。
接下来需要求 v v v作为根节点时的dp状态值。暴力的做法是将 v v v作为根节点,重新上述的过程一遍。但是该方法将造成大量重复计算,未能有效利用之前计算得到的结果。
例如,计算子树 T r e e ( 1 ) Tree(1) Tree(1)的状态值,将1提到根节点处,对其状态值进行重新计算。但是之前1作为子树已经计算得到了结果,能否对此结果加以利用?答案是肯定的。
如图所示,上次遍历时
d
p
(
1
)
,
s
i
z
e
(
1
)
dp(1),size(1)
dp(1),size(1)经过计算,得到是以1为根节点虚线圈中子树的状态值。而原来的父节点
u
u
u此时已经降为
v
v
v的子节点,因此
d
p
(
v
)
dp(v)
dp(v)需要增加
u
u
u子树的那部分状态值,而
d
p
(
u
)
dp(u)
dp(u)需要减掉和
v
v
v相连通的那部分。
可以得到 d p ( 0 ) = d p ( 0 ) − d p ( 1 ) − s i z e ( 1 ) , s i z e ( 0 ) = s i z e ( 0 ) − s i z e ( 1 ) dp(0)=dp(0)-dp(1)-size(1), size(0)=size(0)-size(1) dp(0)=dp(0)−dp(1)−size(1),size(0)=size(0)−size(1),而计算完毕 u u u的新值后,再对 v v v进行更新: d p ( 1 ) = d p ( 1 ) + d p ( 0 ) + s i z e ( 0 ) , s i z e ( 1 ) = s i z e ( 1 ) + s i z e ( 0 ) dp(1)=dp(1)+dp(0)+size(0), size(1)=size(1)+size(0) dp(1)=dp(1)+dp(0)+size(0),size(1)=size(1)+size(0)
用LC的官方题解作为结束语:
时间复杂度:O(n),其中 n 是树中的节点个数。我们只需要遍历整棵树两次即可得到答案,其中每个节点被访问两次,因此时间复杂度为 O(2n)=O(n)。
空间复杂度:O(n)。我们需要线性的空间存图,n 个节点的树包含 n−1 条边,数组dp 和sz 的长度均为 n。
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<string> vs;
typedef unsigned int ui;
class Solution {
vector<vi> w;
int n;
vi vis; // visit
vi ans;
vi sz, dp;
void go(int u) {
vis[u] = 1;
sz[u] = 1;
dp[u] = 0;
for (int v : w[u]) {
if (vis[v])
continue;
go(v);
sz[u] += sz[v];
dp[u] += dp[v] + sz[v];
}
}
void go2(int u) {
vis[u] = 1;
for (int v : w[u]) {
if (vis[v]) continue;
int l_dpu = dp[u], l_szu = sz[u], l_dpv = dp[v], l_szv = sz[v];
dp[u] -= dp[v] + sz[v];
sz[u] -= sz[v];
dp[v] += dp[u] + sz[u];
sz[v] += sz[u];
ans[v] = dp[v];
go2(v);
dp[u] = l_dpu, dp[v] = l_dpv, sz[u] = l_szu, sz[v] = l_szv;
}
}
public:
vector<int> sumOfDistancesInTree(int n, vector<vector<int>>& edges) {
this->n = n;
ans.resize(n);
dp.resize(n);
sz.resize(n);
vis.resize(n);
vi temp;
w.resize(n, temp);
for (vi e : edges) {
int u = e[0], v = e[1];
w[u].push_back(v);
w[v].push_back(u);
}
go(0); // connected tree
ans[0] = dp[0];
for (int i = 0; i < n; ++i)
vis[i] = 0;
go2(0);
return ans;
}
};