如果对于每个点都去寻找与ta距离为2的点的话,就会超时,因为遍历边的次数要多得多。
所以换一种思路,对于每个点,考虑ta作为一个距离为2的点对的中转点,对于答案有什么贡献,这样子我们只需要枚举每个点的出边就行,时间复杂度降到 O ( n + m ) O(n+m) O(n+m)。
从特殊情况入手,当一个点有三条出边,权值分别为 a , b , c a,b,c a,b,c,那么联合权值的和就是 2 a b + 2 a c + 2 b c = ( a + b + c ) 2 − ( a 2 + b 2 + c 2 ) 2ab+2ac+2bc=(a+b+c)^2-(a^2+b^2+c^2) 2ab+2ac+2bc=(a+b+c)2−(a2+b2+c2),其实这样子变形是根据前一个式子的形态很容易想到的,所以我们只需要记录所有出边的和的平方,平方和,就可以解决问题。
其实关键在于能否想到中转点这一巧思!
#include
#include
#include
using namespace std;
typedef long long ll;
struct node
{
int to,next;
}e[400010];
const int mod=10007;
int n,mx,a[200010];
ll sum;
int tot,hd[400010];
void add(int x,int y)
{
e[++tot]=(node){y,hd[x]};
hd[x]=tot;
}
int main()
{
cin>>n;
for(int i=1;i<=n-1;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
}
for(int i=1;i<=n;i++)//O(2n)
{
int mx1=0,mx2=0;
ll sum1=0,sum2=0;
for(int j=hd[i];j;j=e[j].next)
{
if(a[e[j].to]>mx1)
{
mx2=mx1;
mx1=a[e[j].to];
}
else if(a[e[j].to]>mx2) mx2=a[e[j].to];
sum1+=a[e[j].to]%mod;
sum2+=a[e[j].to]*a[e[j].to]%mod;
}
mx=max(mx,mx1*mx2);
sum=sum+(sum1*sum1-sum2)%mod;
}
cout<<mx<<' '<<sum%mod;
return 0;
}