题目链接
Unique Occurrences
题意
给定一棵含有
n
n
n个结点的树,树上的每条边都有一个权值。
f
(
v
,
u
)
f(v,u)
f(v,u)表示
v
v
v到
u
u
u的简单路径上边权只出现一次的边权的个数。求
∑
f
(
v
,
u
)
(
1
≤
v
<
u
≤
n
)
\sum f(v,u)(1 \leq v < u \leq n)
∑f(v,u)(1≤v<u≤n)。
分析
思路一
可以发现边权之间互不影响,可以按边权讨论。对于边权
w
w
w,现在要解决的问题是树上只经过边权为
w
w
w的边一次的简单路径的数量,此问题可以用树形DP求解。
f
[
x
]
[
0
]
f[x][0]
f[x][0]表示从
x
x
x开始向下有
f
[
x
]
[
0
]
f[x][0]
f[x][0]条边不经过边权为
w
w
w的边,
f
[
x
]
[
1
]
f[x][1]
f[x][1]表示从
x
x
x开始向下有
f
[
x
]
[
1
]
f[x][1]
f[x][1]条路径只经过边权为
w
w
w的边一次,状态转移方程如下,其中
W
(
x
,
y
)
W(x,y)
W(x,y) 表示
x
x
x 到
y
y
y 的边权,
x
x
x 是
y
y
y 的父结点。
{
f
[
x
]
[
0
]
=
∑
(
f
[
y
]
[
0
]
+
1
)
,
W
(
x
,
y
)
≠
w
f
[
x
]
[
1
]
+
=
∑
(
f
[
y
]
[
0
]
+
1
)
,
W
(
x
,
y
)
=
w
f
[
x
]
[
1
]
+
=
∑
f
[
y
]
[
1
]
,
W
(
x
,
y
)
≠
w
\left\{
在DP的同时统计所有符合要求的边的数量,每次DP的时间复杂度是
O
(
n
)
O(n)
O(n),这样总的时间复杂度是
O
(
n
2
)
O(n^2)
O(n2)。考虑进行优化,对于一种边权,涉及到的点是有限的,不用建出完整的树,可以只对边权为
w
w
w的边包含的顶点建立虚树,
n
n
n棵树的总点数是
O
(
n
)
O(n)
O(n)的,建虚树需要用到倍增求LCA,总的时间复杂度为
O
(
n
l
o
g
(
n
)
)
O(nlog(n))
O(nlog(n))。
思路二
思路一是按边权讨论,保留边权为
w
w
w的边,其实也可以只删去边权为
w
w
w的边。当删去边权为
w
w
w的边时,原树分成了若干个连通块,假设连通块之间以边权为
w
w
w的虚边连接,那么答案就是所有虚边两侧连通块大小乘积再求和。具体实现时有两种方法,一种是分治+可撤销并查集,另一种是Link Cut Tree,前者时间复杂度是
O
(
n
l
o
g
(
n
)
)
O(nlog(n))
O(nlog(n)),后者时间复杂度是
O
(
n
l
o
g
(
n
)
)
O(nlog(n))
O(nlog(n))。
AC代码
虚树
typedef long long ll;
const int N=5e5+10;
const int M=2*N;
int head[N],e[M],ne[M],w[M],tot;
int a[N],d[N],sz[N],dfn[N],stk[N],f[N][21];
vector<pair<int,int>> vec[N];
map<int,int> col[N];
map<int,set<int>> mp;
int n,t,cnt,num,top;
ll g[N][2];
ll ans;
void add(int x,int y,int z)
{
e[++tot]=y,ne[tot]=head[x],w[tot]=z,head[x]=tot;
}
void dfs(int x)
{
dfn[x]=++num;
sz[x]=1;
for(int i=head[x];i;i=ne[i])
{
int y=e[i];
if(!d[y])
{
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=t;j++) f[y][j]=f[f[y][j-1]][j-1];
dfs(y);
sz[x]+=sz[y];
}
}
}
int getlca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[y][i]]>=d[x])
y=f[y][i];
if(x==y) return x;
for(int i=t;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
void build(int k)
{
cnt=0;
for(auto x:mp[k]) if(x!=1) a[++cnt]=x;
sort(a+1,a+cnt+1,cmp);
top=0; stk[++top]=1;
vec[1].clear();
for(int i=1;i<=cnt;i++)
{
int lca=getlca(a[i],stk[top]);
if(lca!=stk[top])
{
while(dfn[lca]<dfn[stk[top-1]])
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
if(dfn[lca]>dfn[stk[top-1]])
{
vec[lca].clear();
vec[lca].push_back({stk[top],col[lca].count(stk[top])?col[lca][stk[top]]:0});
top--;
stk[++top]=lca;
}
else
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
}
vec[a[i]].clear();
stk[++top]=a[i];
}
while(top>1)
{
int x=stk[top-1],y=stk[top];
vec[x].push_back({y,col[x].count(y)?col[x][y]:0});
top--;
}
}
void dp(int x,int p,int k)
{
g[x][1]=0; g[x][0]=sz[x]-1;
for(auto it:vec[x]) g[x][0]-=sz[it.first];
for(auto it:vec[x])
{
int y=it.first,z=it.second;
if(y==p) continue;
dp(y,x,k);
if(z==k)
{
ans+=(g[y][0]+1)*g[x][0];
g[x][1]+=(g[y][0]+1);
}
else
{
ans+=(g[y][0]+1)*g[x][1];
ans+=g[y][1]*g[x][0];
g[x][0]+=(g[y][0]+1);
g[x][1]+=g[y][1];
}
}
ans+=g[x][1];
}
int main()
{
cin>>n;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
add(x,y,z);
add(y,x,z);
col[x][y]=col[y][x]=z;
mp[z].insert(x);
mp[z].insert(y);
}
while((1<<t)<n) t++;
d[1]=1; dfs(1);
for(int k=1;k<=n;k++)
{
if(mp[k].size())
{
build(k);
dp(1,0,k);
}
}
cout<<ans<<endl;
return 0;
}
分治+可撤销并查集
map<int,vector<pair<int,int>>> mp;
ll ans;
int n;
void dfs(int l,int r)
{
if(l==r)
{
for(auto it:mp[l])
{
ans+=(ll)dsu.siz(it.first)*dsu.siz(it.second);
}
return ;
}
int mid=(l+r)>>1;
int h=dsu.histroy();
for(int i=l;i<=mid;i++)
{
for(auto it:mp[i])
{
dsu.merge(it.first,it.second);
}
}
dfs(mid+1,r);
dsu.roll(h);
for(int i=mid+1;i<=r;i++)
{
for(auto it:mp[i])
{
dsu.merge(it.first,it.second);
}
}
dfs(l,mid);
dsu.roll(h);
}
int main()
{
cin>>n;
dsu.init(n);
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
mp[z].push_back({x,y});
}
dfs(1,n);
cout<<ans<<endl;
return 0;
}
const int N=5e5+10;
vector<pair<int,int>> vec[N];
int main()
{
int n; cin>>n;
for(int i=1;i<=n;i++) lct.tr[i].sz=1;
for(int i=1;i<n;i++)
{
int x,y,z;
cin>>x>>y>>z;
lct.link(x,y);
vec[z].push_back({x,y});
}
ll ans=0;
for(int i=1;i<=n;i++)
{
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.cut(x,y);
}
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.makeroot(x),lct.makeroot(y);
ans+=(ll)lct.tr[x].sz*lct.tr[y].sz;
}
for(auto it:vec[i])
{
int x=it.first,y=it.second;
lct.link(x,y);
}
}
cout<<ans<<endl;
return 0;
}