给出一棵树,在树上选三个点,使得他们两两树上距离相等,问有多少种选法?
能成立的点对应该是如图这样的关系

选出的三个点是 ( x , y , z ) (x,y,z) (x,y,z),他们由一个中转点 T T T中转,使得 X T = Y T = Z T = k XT=YT=ZT=k XT=YT=ZT=k,即三个点分别从 T T T的三棵子树中选取
但由于树的特性,至少有一支成为父亲,暴力枚举子树的时间复杂度是难以接受的
于是考虑树形 d p dp dp,设 f [ u ] [ i ] f[u][i] f[u][i]代表在 u u u的子树内距离 u u u为 i i i的点有多少个, g [ u ] [ i ] g[u][i] g[u][i]代表在 u u u的子树内,存在多少对二元组 ( x , y ) (x,y) (x,y),令 l c a lca lca为 x , y x,y x,y的最近公共祖先,有 d i s ( x , l c a ) = d i s ( y , l c a ) dis(x,lca)=dis(y,lca) dis(x,lca)=dis(y,lca)并且 d i s ( u , l c a ) = d i s ( x , l c a ) − i dis(u,lca)=dis(x,lca)-i dis(u,lca)=dis(x,lca)−i成立。
为什么这样考虑?可以发现这样设置 g g g后, x , y x,y x,y被确定了,那么只需要去寻找 z z z,而 l c a lca lca距离 u u u是 d i s ( x , l c a ) − i dis(x,lca)-i dis(x,lca)−i,我们要找到另外子树的一个点 z z z,使 d i s ( z , l c a ) = d i s ( x , l c a ) dis(z,lca)=dis(x,lca) dis(z,lca)=dis(x,lca),那么只需要从 u u u出发,找到另外子树的距离 u u u长度为 i i i的点,即 f [ u ] [ i ] f[u][i] f[u][i]。
不妨将 g [ u ] [ i ] g[u][i] g[u][i]解释为一个等价形式,意为在 u u u的子树内,存在多少点对 ( x , y ) (x,y) (x,y),使得其他子树内距离 u u u为 i i i的点都可以成为 z z z
并且由于他们以深度为下标,那么可以考虑长链剖分来合并 d p dp dp
a n s + = g [ u ] [ 0 ] + ∑ v ∈ s u m u g [ v ] [ i ] ∗ f [ u ] [ i − 1 ] + g [ u ] [ i + 1 ] ∗ f [ v ] [ i ] ans+=g[u][0]+\sum_{v\in sum_{u}} g[v][i]*f[u][i-1]+g[u][i+1]*f[v][i] ans+=g[u][0]+v∈sumu∑g[v][i]∗f[u][i−1]+g[u][i+1]∗f[v][i]
g [ u ] [ i − 1 ] + = ∑ v ∈ s u m u g [ v ] [ i ] (1) g[u][i-1]+=\sum_{v\in sum_{u}} g[v][i] \tag{1} g[u][i−1]+=v∈sumu∑g[v][i](1)
g [ u ] [ i + 1 ] + = f [ v ] [ i ] ∗ f [ u ] [ i + 1 ] (2) g[u][i+1]+=f[v][i]*f[u][i+1]\tag{2} g[u][i+1]+=f[v][i]∗f[u][i+1](2)
f [ u ] [ i + 1 ] = f [ v ] [ i ] f[u][i+1]=f[v][i] f[u][i+1]=f[v][i]
以上的 i i i都是枚举 v v v的长链长,因此 j ∈ [ 0 , m a x 1 [ v ] − 1 ] j\in[0,max1[v]-1] j∈[0,max1[v]−1]
答案的计算不妨考虑结果点 z z z在哪里?如果恰好在 u u u,那么答案即 g [ u ] [ 0 ] g[u][0] g[u][0],如果不在 u u u,在另外的子树上,那么 a n s + = ∑ i = 0 m a x 1 [ v ] − 1 g [ u ] [ i ] ∗ f [ u ] [ i ] ans+=\sum_{i=0}^{max1[v]-1}g[u][i]*f[u][i] ans+=∑i=0max1[v]−1g[u][i]∗f[u][i],如果在当前子树上,他也会被 ( x , y ) (x,y) (x,y)在的子树比较小的时候归入不在那个时候的子树的情况,不需要重复计数。
g g g的转移 ( 1 ) (1) (1)可以理解为,到达了 v v v还需要 i i i的长度才能到达 z z z,那么在 u u u还需要的长度自然 − 1 -1 −1,于是 g [ u ] [ i − 1 ] = g [ v ] [ i ] g[u][i-1]=g[v][i] g[u][i−1]=g[v][i]
转移 ( 2 ) (2) (2),由于转移 u u u的时候,所有以 u u u的儿子为 l c a lca lca的情况都被添加到 g [ u ] g[u] g[u]了,那么只需要考虑以 u u u为 l c a lca lca的情况,此时如果 u u u为 l c a lca lca,那么按照定义, d i s ( x , l c a ) − i = d i s ( l c a , u ) = 0 dis(x,lca)-i=dis(lca,u)=0 dis(x,lca)−i=dis(lca,u)=0,于是 d i s ( x , l c a ) = d i s ( y , l c a ) = i dis(x,lca)=dis(y,lca)=i dis(x,lca)=dis(y,lca)=i,那么 g [ u ] [ i ] g[u][i] g[u][i]就是添加有多少点对 ( x , y ) (x,y) (x,y),满足不在 u u u的同一颗子树内,距离 u u u距离都为 i i i的点,这样按照前缀和合并子树的思想就可以计算出来
显然我们统计答案是用当前子树的答案和之前所有子树的答案计算,于是计算 a n s ans ans需要放在转移 f , g f,g f,g之前,这是合并子树前缀和计算的思想。
转移式中计算 a n s ans ans是还需要加上 g [ u ] [ 0 ] g[u][0] g[u][0],这项必须在计算完长儿子后,统计短儿子前加入,如果在最后加入,那么 g [ u ] [ 0 ] g[u][0] g[u][0]包含了 g [ v ] [ 1 ] g[v][1] g[v][1], g [ v ] [ 1 ] g[v][1] g[v][1]又包含了 g [ v v ] [ 2 ] g[vv][2] g[vv][2],而 g [ v ] [ 1 ] g[v][1] g[v][1]已经和 f [ u ] [ 1 ] f[u][1] f[u][1]计数了,所以不能在最后计数
按照转移式, f f f和 g ( 1 ) g(1) g(1)是可以 O ( 1 ) O(1) O(1)继承长儿子的,对 g g g, g [ s o n [ u ] ] [ i ] g[son[u]][i] g[son[u]][i]被写在 g [ u ] [ i − 1 ] g[u][i-1] g[u][i−1]上,那么 g [ s o n [ u ] ] = g [ u ] − 1 g[son[u]]=g[u]-1 g[son[u]]=g[u]−1
,同理 f [ s o n [ u ] ] = f [ u ] + 1 f[son[u]]=f[u]+1 f[son[u]]=f[u]+1,两个指针偏移方向不一样。一种解决方法是,只用一个数组供 f , g f,g f,g一起填写,每次分配多一倍的空间,保证有足够空间可以写。先分配 f f f再分配 g g g的顺序不能调换,除非被写数组的开始分配的点在末尾
f[v]=pp; pp+=max1[v]<<1;
g[v]=pp; pp+=max1[v]<<1;
#include
#define ll long long
using namespace std;
struct way
{
int to,next;
}edge[200005];
int cntt,head[100005];
void add(int u,int v)
{
edge[++cntt].to=v;
edge[cntt].next=head[u];
head[u]=cntt;
}
int n,max1[100005],son[100005],depth[100005];
void dfs1(int u,int fa)
{
for(int i=head[u];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa) continue;
dfs1(v,u);
if(max1[v]>max1[son[u]]) son[u]=v;
}
max1[u]=max1[son[u]]+1;
}
ll ans=0,*f[200005],*g[200005],tmp1[400005],*pp=tmp1;
void dfs(int u,int fa)
{
if(son[u])
{
f[son[u]]=f[u]+1;
g[son[u]]=g[u]-1;
dfs(son[u],u);
}
f[u][0]=1; ans+=g[u][0];
for(int i=head[u];i;i=edge[i].next)
{
int v=edge[i].to;
if(v==fa||v==son[u]) continue;
f[v]=pp; pp+=max1[v]<<1;
g[v]=pp; pp+=max1[v]<<1;
dfs(v,u);
for(int j=0;j<max1[v];j++)
{
ans+=f[v][j]*g[u][j+1];
if(j) ans+=g[v][j]*f[u][j-1];
}
for(int j=0;j<max1[v];j++)
{
g[u][j+1]+=f[u][j+1]*f[v][j];
if(j) g[u][j-1]+=g[v][j];
f[u][j+1]+=f[v][j];
}
}
}
int main()
{
cin>>n;
for(int i=1;i<n;i++)
{
int u,v; scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
dfs1(1,0);
f[1]=pp; pp+=max1[1]<<1;
g[1]=pp; pp+=max1[1]<<1;
dfs(1,0);
cout<<ans;
return 0;
;
g[1]=pp; pp+=max1[1]<<1;
dfs(1,0);
cout<<ans;
return 0;
}