给定一棵
n
n
n个节点的树,每个节点有
a
[
i
]
a[i]
a[i]个蝴蝶,如果蝴蝶被惊动了会在
t
[
i
]
t[i]
t[i]时间后飞走,被惊动的条件是你走到该节点相邻的节点。最开始
0
0
0时刻你在
1
1
1号节点,走一条边需要时间为
1
1
1,你有无限时间抓蝴蝶,问最多能抓到多少蝴蝶。
1
<
=
n
<
=
1
0
5
1<=n<=10^5
1<=n<=105
1
<
=
a
i
<
=
1
0
9
1<=ai<=10^9
1<=ai<=109
1
<
=
t
i
<
=
3
1<=ti<=3
1<=ti<=3
首先可以从题目中得到,当到达一个节点时,该节点的子节点上的蝴蝶全都会被惊动,而从一个子节点走到另一个子节点花费的时间为 3 3 3。所以当选择抓一个子节点的蝴蝶后,还可以选择一个时间 t [ i ] = 3 t[i] = 3 t[i]=3的子节点抓蝴蝶,其他子节点的蝴蝶都会飞走。
给出
D
P
DP
DP方程定义:
f
[
u
]
[
1
]
[
1
]
f[u][1][1]
f[u][1][1]:该点的蝴蝶抓,且该点的子节点至少抓一个
f
[
u
]
[
1
]
[
0
]
f[u][1][0]
f[u][1][0]:该点的蝴蝶抓,但该点的子节点一个都不抓
f
[
u
]
[
0
]
[
1
]
f[u][0][1]
f[u][0][1]:该点的蝴蝶不抓,但该点的子节点的至少抓一个节点的蝴蝶
现在考虑:
对于父节点
u
u
u的子节点
v
i
vi
vi ,我们至少可以选择一个节点得到
a
[
v
]
a[v]
a[v]的蝴蝶。其他的子节点
v
i
vi
vi的子节点一定至少可以有一个可以抓到蝴蝶。
若
t
[
v
i
]
=
3
t[vi] = 3
t[vi]=3 则可以到其他分支
v
j
vj
vj上再取一点再跑回该节点
v
i
vi
vi来, 但是取的另一个节点
v
j
vj
vj的子节点的蝴蝶一定抓不到了。
所以我们对于每个节点
u
u
u的子节点有两种取法
设一个节点有
m
m
m个子节点
1
1
1.抓一个子节点舍弃其他点的蝴蝶
f
[
v
i
]
[
1
]
[
1
]
+
∑
k
=
1
m
f
[
v
k
]
[
0
]
[
1
]
(
i
!
=
k
)
f[vi][1][1] + \sum_{k=1}^{m}f[vk][0][1](i != k)
f[vi][1][1]+∑k=1mf[vk][0][1](i!=k)
2
2
2.当我们选择抓的蝴蝶时间为
t
[
v
i
]
=
3
t[vi] = 3
t[vi]=3 我们可以再抓一个点的蝴蝶
f
[
v
i
]
[
1
]
[
1
]
+
f
[
v
j
]
[
1
]
[
0
]
+
∑
k
=
1
m
f
[
v
k
]
[
0
]
[
1
]
(
i
!
=
j
!
=
k
)
f[vi][1][1] + f[vj][1][0] + \sum_{k = 1}^{m}f[vk][0][1](i != j != k)
f[vi][1][1]+f[vj][1][0]+∑k=1mf[vk][0][1](i!=j!=k)
我们可以枚举要抓的蝴蝶
v
i
vi
vi,对于是否要再抓一次蝴蝶取最大值即可
#include
#include
#include
using namespace std;
#define ll long long
const int N = 1e5 + 10;
typedef pair<int,int>PII;
int a[N],t[N],n;
vector<int>g[N];
ll f[N][2][2];
/*
f[u][1][1]:该点的值取,且该点子节点的至少取一个
f[u][1][0]:该点的值取,但该点子节点的值不取
f[u][0][1]:该点的值不取,该点子节点的至少取一个
*/
struct node
{
int id;//节点编号
ll res;//值
bool operator < (const node &A)const{
return res > A.res;
}
};
void dfs(int u,int fa)
{
f[u][0][1] = 0;
f[u][1][1] = f[u][1][0] = a[u];//初始化
vector<node>tmp;
ll sum = 0,ans = 0;
//先全部取第一种情况 即只抓一个节点的蝴蝶,其余节点的蝴蝶舍去
for(int v : g[u])
{
if(v == fa) continue ;
dfs(v, u);
sum += f[v][0][1];//记录下所有子节点都可以取子节点的和
tmp.push_back({v, f[v][1][0] - f[v][0][1]});//第二种情况,选一个节点v先取,再去最后要取的节点
}
sort(tmp.begin(), tmp.end());
int siz = tmp.size();
for(int v : g[u])//枚举一定取的节点即f[v][1][1],对于两种情况取最大值
{
if(v == fa) continue ;
ans = max(ans, sum + f[v][1][1] - f[v][0][1]);
if(t[v] == 3)
{
//注意第二种情况抓的两个子节点的蝴蝶,节点编号不能重叠
if(tmp[0].id != v) ans = max(ans, sum + f[v][1][1] - f[v][0][1] + tmp[0].res);
else if(siz >= 2) ans = max(ans, sum + f[v][1][1] - f[v][0][1] + tmp[1].res);
}
}
f[u][1][1] += ans;
f[u][0][1] += ans;
f[u][1][0] += sum;
return ;
}
void solve()
{
scanf("%d",&n);
for(int i = 1; i <= n; i ++){
scanf("%d",&a[i]);
}
for(int i = 1; i <= n; i ++){
scanf("%d",&t[i]);
g[i].clear();
}
for(int i = 1; i < n; i ++){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
printf("%lld\n",f[1][1][1]);
return ;
}
int main()
{
int T;
scanf("%d",&T);
while(T --) solve();
return 0;
}