求最小生成树的常用算法有 Kruskal,Prim \text{Kruskal,Prim} Kruskal,Prim。 Kruskal \text{Kruskal} Kruskal 的时间复杂度是与边数相关的,而此题边数很多,所以考虑 Prim \text{Prim} Prim 的思想。
维护两个点集:已匹配点集个未匹配点集。每次找出两个点集的最短边,将边的未匹配点集点加入已匹配点集。
对于此题,考虑使用线段树维护两个点集的最短边权值。
m
a
x
0
r
t
max0_{rt}
max0rt 表示
r
t
rt
rt 所代表的区间中,未匹配的点中
x
x
x 值最大的。
m
a
x
1
r
t
max1_{rt}
max1rt 表示
r
t
rt
rt 所代表的区间中,已匹配的点中
x
x
x 值最大的。
m
i
n
0
r
t
min0_{rt}
min0rt 表示
r
t
rt
rt 所代表的区间中,未匹配的点中
x
x
x 值最小的。
m
i
n
1
r
t
min1_{rt}
min1rt 表示
r
t
rt
rt 所代表的区间中,已匹配的点中
x
x
x 值最小的。
a
n
s
r
t
ans_{rt}
ansrt 表示
r
t
rt
rt 所代表的区间中的两个点之间的最小边权,其中两个点不能是一个点集。
在 p u s h u p pushup pushup 过程中, m a x 0 , m a x 1 , m i n 0 , m i n 1 max0,max1,min0,min1 max0,max1,min0,min1 只需简单地从儿子区间更新。
a n s ans ans 先从儿子区间更新,再考虑两个分居两个区间的点之间的边权,容易想到要边权最小,一定是选右边最小的和左边最大的(注意两点要不同点集), a n s ans ans 用其更新即可。
每次找到最短边后,将边的未匹配点集点 x x x 加入已匹配点集,实现可以把 x x x 在线段树单点交换 m a x 0 , m a x 1 max0,max1 max0,max1 和 m i n 0 , m i n 1 min0,min1 min0,min1。(相当于将这个点的属性有未匹配变为匹配了)
时间复杂度 O ( n log n ) O(n\log n) O(nlogn)
具体实现参见代码
#include
using namespace std;
const int N=3e5+1,INF=1e9;
int n,a[N];
long long Ans;
struct node
{
int v,id;
node(){}
node(int a,int b){v=a,id=b;}
bool operator<(const node &a)const{
return v<a.v;
}
}max0[N<<2],max1[N<<2],min0[N<<2],min1[N<<2],ans[N<<2];
void pushup(int rt)
{
max0[rt]=max(max0[rt<<1],max0[rt<<1|1]);
max1[rt]=max(max1[rt<<1],max1[rt<<1|1]);
min0[rt]=min(min0[rt<<1],min0[rt<<1|1]);
min1[rt]=min(min1[rt<<1],min1[rt<<1|1]);
ans[rt]=min({ans[rt<<1],ans[rt<<1|1],min(node(min1[rt<<1|1].v-max0[rt<<1].v,max0[rt<<1].id),
node(min0[rt<<1|1].v-max1[rt<<1].v,min0[rt<<1|1].id))});
}
void build(int rt,int l,int r)
{
if(l==r){
max0[rt]=min0[rt]=node(a[l],l);
max1[rt]=node(-INF,-1);
min1[rt]=node(INF,-1);
ans[rt]=node(INF,-1);
return;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int x)
{
if(l==r){
swap(max0[rt],max1[rt]),swap(min0[rt],min1[rt]);
return;
}
int mid=l+r>>1;
if(x<=mid) update(rt<<1,l,mid,x);
else update(rt<<1|1,mid+1,r,x);
pushup(rt);
}
int main()
{
freopen("mst.in","r",stdin);
freopen("mst.out","w",stdout);
cin.tie(0)->sync_with_stdio(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
update(1,1,n,1);
for(int i=1;i<n;i++){
Ans+=ans[1].v;
update(1,1,n,ans[1].id);
}
cout<<Ans;
}