一般的线段树维护的是区间上的信息,比如区间最大值,最小值,区间和等,维护的区间都对应数组的一段连续的下标。
而权值线段树于此不同。权值线段树是指以值域为区间的线段树,它维护的区间是一段值域,所以称之为权值线段树。
因为权值线段树维护的是值域,而值域的范围可能很大,且值域中的许多位置是空的。我们可以省去这些空间,只将需要用的点表示出来即可。此时,叶节点的数量就是元素的数量。若元素值个数为 n n n,值域为 m m m,则每个叶节点的祖先最多 l o g m logm logm个,所以这棵树的节点数量也不会超过 n l o g m nlogm nlogm个。
线段树合并指将多棵线段树上的信息合并到一棵线段树上。每次合并两棵,合并后的信息就在其中一棵线段树上。
对于每次合并,我们从根节点往下依次遍历每个位置,如果某个位置两棵树都存在节点,则将一棵树上的信息加到另一棵树上,然后继续处理它们的左儿子和右儿子。如果某个位置只有一棵树有节点或者两棵树都没有节点,则只需处理当前位置,不需要再往下合并了。
void merge(int &r1,int r2,int l,int r){
if(!r1||!r2){
r1=r1+r2;return;
}
if(l==r){
s[r1]+=s[r2];return;
}
int mid=(l+r)/2;
merge(tr[r1].lc,tr[r2].lc,l,mid);
merge(tr[r1].rc,tr[r2].rc,mid+1,r);
s[r1]=s[tr[r1].lc]+s[tr[r1].rc];
}
这道题需要用到并查集和线段树合并。
开始时对每一座岛建立一棵线段树,第 i i i个叶节点为一表示重要程度为 i i i的岛与这个岛连通。维护区间叶节点数量 s s s, r t [ i ] rt[i] rt[i]表示节点 i i i的权值线段树的根,然后对 M M M条边进行合并。比如对于边 ( a , b ) (a,b) (a,b),如果 a , b a,b a,b不在一个连通块,则设 x = f i n d ( a ) , y = f i n d ( b ) x=find(a),y=find(b) x=find(a),y=find(b),将 x x x的线段树和 y y y的线段树进行合并,然后 f a [ y ] = x fa[y]=x fa[y]=x,表示将 y y y的线段树合并到 x x x的线段树中。
对于 B B B操作,同上,并查集处理即可。对于 Q Q Q操作,查找 x x x的线段树,找到第 k k k小的叶节点即可。
#include
using namespace std;
int n,m,q,tot=0,x,y,c[100005],re[100005],rt[100005],fa[100005],s[5000005];
char ch;
struct node{
int lc,rc;
}tr[5000005];
int find(int ff){
if(fa[ff]!=ff) fa[ff]=find(fa[ff]);
return fa[ff];
}
void pt(int &k,int l,int r,int v){
if(!k) k=++tot;
if(l==r&&l==v){
++s[k];return;
}
if(l>v||r<v) return;
if(l==r) return;
int mid=(l+r)/2;
if(v<=mid) pt(tr[k].lc,l,mid,v);
else pt(tr[k].rc,mid+1,r,v);
s[k]=s[tr[k].lc]+s[tr[k].rc];
}
void merge(int &r1,int r2,int l,int r){
if(!r1||!r2){
r1=r1+r2;return;
}
if(l==r){
s[r1]+=s[r2];return;
}
int mid=(l+r)/2;
merge(tr[r1].lc,tr[r2].lc,l,mid);
merge(tr[r1].rc,tr[r2].rc,mid+1,r);
s[r1]=s[tr[r1].lc]+s[tr[r1].rc];
}
void gt(int x,int y){
int x1=find(x),x2=find(y);
if(x1==x2) return;
merge(rt[x1],rt[x2],1,n);
fa[x2]=x1;
}
int find(int k,int l,int r,int v){
if(!k) return -1;
if(l==r){
if(v==1) return re[l];
return -1;
}
int mid=(l+r)/2;
if(v<=s[tr[k].lc]&&tr[k].lc) return find(tr[k].lc,l,mid,v);
return find(tr[k].rc,mid+1,r,v-s[tr[k].lc]);
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&c[i]);
re[c[i]]=i;fa[i]=i;
pt(rt[i],1,n,c[i]);
}
for(int i=1;i<=m;i++){
scanf("%d%d",&x,&y);
gt(x,y);
}
scanf("%d",&q);
while(q--){
ch=getchar();
while(ch!='B'&&ch!='Q') ch=getchar();
scanf("%d%d",&x,&y);
if(ch=='B'){
gt(x,y);
}
else{
x=find(x);
printf("%d\n",find(rt[x],1,n,y));
}
}
return 0;
}