给定一颗 n
n 个点的树,每条边有边权 v(|v|≤106)v(|v|≤106) ,要求删去其中任意 kk 条边,使得剩余联通块的直径之和最大。求出这个最大值。0≤k<n≤3×105,10s,1GB。
问题是怎么求直径?!直径不就是最大的链吗?!
发现原问题等价于选择 k+1 条互不相交的链使得链的总价值最大。
设 dpi,j,0/1/2 表示到 i 子树,已经选择了 j 条链,当前根上的选择情况分别为当前根不和父亲合并、向下连接一条链、向下连接形成两条链的情况时最大收益。
在加入一棵子树的时候合并答案:
dpu,2=max(dpu,2+dpv,0,dpu,1+dpv,1+edg−e)dpu,1=max(dpu,1+dpv,0,dpu,0+dpv,1+edg)dpu,0=dpu,0+dpv,0
在出子树的时候将 dpx,1,dpx,2 都合并到 dpx,0 表示 x 不和父亲合并的情况。
那么这样可以 O(nk) 转移。
看题解,可以了解到这个状态和 k 的关系是:随 k 增大,答案先增大后减小。
马后炮 yy 一下发现其实比较容易感性理解,太小了就没得选大边,太大了就不得不因为链不能重合舍弃大边。
那么在这个以选择链的数量为横坐标,当前最大收益为纵坐标的二维 DP 上,我们要得出以 k 为横坐标的点的答案。
如果我们直接去掉上面 DP 的 j 一维,我们能够得到对于所有 k 的收益最大值,可以通过二分斜率来找到 k 的值。
二分选择一条链的额外代价,去掉 DP 中“已经选择了几条链”的那一维,直接记录选择链的最大权值之和。同时需要维护一个计数器数组和 DP 一起转移统计已经选择了多少条链。
那么如果选择的链多了就增大选择一条链的额外代价,否则减少,知道刚好等于 k,那么最终权值就是 val+k×e,其中 e 是额外代价。
这样最终求出的最大值加上 e×k 就是答案了。
其中有几个地方需要注意:
- 二分时,如果答案选择的链的数量 ≥k 则更新 ans,因为链的数量为 n 总是能够取到,而数量极少则不一定能取到。
- 横坐标为 k 的点可能和 k−1,k+1 构成直线,不一定能够准确二分到 k,所以最终答案要加上e×k 而不是当前选择链的条数。
- 二分的值域要到 [−n×106,n×106]。
- #define Maxn 300005
- #define int long long
- int n,k,tot,curmuti;
- int hea[Maxn],nex[Maxn<<1],ver[Maxn<<1],edg[Maxn<<1];
- struct NODE
- {
- int val,hav;
- NODE(int _val=0,int _hav=0):val(_val),hav(_hav){};
- inline bool friend operator < (NODE x,NODE y)
- { return (x.val!=y.val)?x.val
- inline NODE friend operator + (NODE x,NODE y)
- { return NODE(x.val+y.val,x.hav+y.hav); }
- };
- NODE dp[Maxn][3];
- inline void add(int x,int y,int d){ ver[++tot]=y,nex[tot]=hea[x],hea[x]=tot,edg[tot]=d; }
- void dfs(int x,int fa)
- {
- dp[x][2]=max(dp[x][2],NODE(-curmuti,1));
- for(int i=hea[x];i;i=nex[i]) if(ver[i]!=fa)
- {
- dfs(ver[i],x);
- dp[x][2]=max(
- dp[x][2]+dp[ver[i]][0],
- dp[x][1]+dp[ver[i]][1]+NODE(edg[i]-curmuti,1));
- dp[x][1]=max(
- dp[x][1]+dp[ver[i]][0],
- dp[x][0]+dp[ver[i]][1]+NODE(edg[i],0));
- dp[x][0]=dp[x][0]+dp[ver[i]][0];
- }
- dp[x][0]=max(dp[x][0],max(dp[x][1]+NODE(-curmuti,1),dp[x][2]));
- }
- signed main()
- {
- n=rd(),k=rd()+1;
- for(int i=1,x,y,d;i
rd(),y=rd(),d=rd(),add(x,y,d),add(y,x,d); - int nl=-n*1000000,nr=n*1000000,ret=0;
- while(nl<=nr)
- {
- int mid=(nl+nr)>>1;
- memset(dp,0,sizeof(dp)),curmuti=mid,dfs(1,0);
- if(dp[1][0].hav>=k) ret=mid,nl=mid+1;
- else nr=mid-1;
- }
- memset(dp,0,sizeof(dp)),curmuti=ret,dfs(1,0);
- printf("%lld\n",dp[1][0].val+ret*k);
- return 0;
- }