Splay
前言
Splay是一种维护平衡二叉树的算法。虽然它常数大,而且比较难打,但Splay十分方便,而且LCT需要用到。
约定
cnticnt_icnti:节点iii的个数
valival_ivali:节点iii的权值
sizisiz_isizi:节点iii的子树大小
chi,0/1ch_{i,0/1}chi,0/1:节点iii的左右儿子
faifa_ifai节点iii的父亲
rootrootroot当前的根节点
tottottot当前的节点数量
gt(x)gt(x)gt(x):返回xxx是左儿子还是右儿子
pushup(x)pushup(x)pushup(x):更新当前子树大小
int gt(int x){return ch[fa[x]][1]==x;
}
void pt(int x){siz[x]=cnt[x]+siz[ch[x][0]]+siz[ch[x][1]];
}
基本操作
旋转操作
- yyy是zzz的哪个儿子,xxx就是zzz的哪个儿子
- xxx是yyy的哪个儿子,yyy就是xxx的对应儿子的兄弟
- xxx是yyy的哪个儿子,yyy的那个儿子就是xxx的对应儿子的兄弟
void rot(int x){int y=fa[x],z=fa[y],k=gt(x);ch[z][gt(y)]=x,fa[x]=z;ch[y][k]=ch[x][!k],fa[ch[y][k]]=y;ch[x][!k]=y,fa[y]=x;pt(y);pt(x);
}
伸展操作
splay(x,g)splay(x,g)splay(x,g),表示将xxx旋转到ggg下面。
我们可以一直rotrotrot,但如果xxx的父亲不是ggg且xxx和xxx的父亲是同一边的儿子,则可以旋转父亲。先旋转父亲可以减少深度。
void splay(int x,int g=0){for(int y;fa[x]!=g;rot(x)){y=fa[x];if(fa[y]!=g) rot((gt(x)==gt(y))?y:x);}if(!g) root=x;
}
普通操作
find操作
找到值最接近xxx的点,并伸展到根。
void find(int x){if(!root) return;int u=root;while(ch[u][x>v[u]]&&x^v[u]) u=ch[u][x>v[u]];splay(u);
}
insert操作
插入值为xxx的点,需进行一下操作
- 找到插入点的位置
- 如果存在值为xxx的点,则加对应的cntcntcnt
- 否则新加一个点
- 把该节点伸展到根
void insert(int x){int u=root,fu=0;while(u&&v[u]!=x){fu=u,u=ch[u][x>v[u]];}if(u) ++cnt[u];else{u=++tot;if(fu) ch[fu][x>v[fu]]=u;fa[u]=fu;v[u]=x;cnt[u]=siz[u]=1;}splay(u);
}
前驱和后继
求xxx的前驱和后继
先find(x)find(x)find(x),那么前驱就是左子树中最大的一个,后继就是右子树中最小的一个。
int nxt(int x,int f){find(x);int u=root;if(v[u]>x&&f) return u;if(v[u]<x&&!f) return u;u=ch[u][f];while(ch[u][!f]) u=ch[u][!f];return u;
}
delete操作
删除值为xxx的点
首先找到xxx的前驱sucsucsuc和xxx的后继preprepre,然后
- splay(pre)splay(pre)splay(pre)
- splay(suc,pre)splay(suc,pre)splay(suc,pre)
然后sucsucsuc的左子树就是要删除的点,删除即可。
void dele(int x){int lt=nxt(x,0),nt=nxt(x,1);splay(lt);splay(nt,lt);int tx=ch[nt][0];if(cnt[tx]>1) --cnt[tx],splay(tx);else ch[nt][0]=0;
}
kth操作
找到排名为kkk的节点的权值
int kth(int k){int u=root,sn=0;for(;;){sn=ch[u][0];if(k>siz[sn]+cnt[u]) k-=siz[sn]+cnt[u],u=ch[u][1];else if(siz[sn]>=k) u=sn;else return v[u];}
}
例题
普通平衡树
对于操作1,2,4,5,6,可以用上述操作解决即可。对于操作3,可以find(x)find(x)find(x)将其置为根,然后xxx的排名就是它左子树的节点个数+1+1+1。
注意为了防止splaysplaysplay出锅,要在加上两个节点∞\infty∞和−∞-\infty−∞。注意这两个节点对操作的影响。
#include<iostream>
#include<cstdio>
#define N 500000
using namespace std;
int root,tot,t,cnt[N],v[N],siz[N],ch[N][2],fa[N];
int gt(int x){return ch[fa[x]][1]==x;
}//Return 0 or 1 means x is the left or right son
void pt(int x){siz[x]=cnt[x]+siz[ch[x][0]]+siz[ch[x][1]];
}//Update the x
void rot(int x){int y=fa[x],z=fa[y],k=gt(x);ch[z][gt(y)]=x,fa[x]=z;ch[y][k]=ch[x][!k],fa[ch[y][k]]=y;ch[x][!k]=y,fa[y]=x;pt(y);pt(x);
}//Rotate
void splay(int x,int g=0){for(int y;fa[x]!=g;rot(x)){y=fa[x];if(fa[y]!=g) rot((gt(x)==gt(y))?y:x);}if(!g) root=x;
}//Put the x under the g
void find(int x){if(!root) return;int u=root;while(ch[u][x>v[u]]&&x^v[u]) u=ch[u][x>v[u]];splay(u);
}//Find the closest node and put it under the root
void insert(int x){int u=root,fu=0;while(u&&v[u]!=x){fu=u,u=ch[u][x>v[u]];}if(u) ++cnt[u];else{u=++tot;if(fu) ch[fu][x>v[fu]]=u;fa[u]=fu;v[u]=x;cnt[u]=siz[u]=1;}splay(u);
}//Insert the x
//1.Find the root which should be inserted
//2.If there is a node as same as it,plus its cnt
//3.Else plus a node
//4.Make the new node be the root
int nxt(int x,int f){find(x);int u=root;if(v[u]>x&&f) return u;if(v[u]<x&&!f) return u;u=ch[u][f];while(ch[u][!f]) u=ch[u][!f];return u;
}//Find the suc or the pre of the x
//After finding x,then
//The pre is the biggest one in left tree
//The suc is the smallest one in right tree
void dele(int x){int lt=nxt(x,0),nt=nxt(x,1);splay(lt);splay(nt,lt);int tx=ch[nt][0];if(cnt[tx]>1) --cnt[tx],splay(tx);else ch[nt][0]=0;
}//Find the pre and the suc of the x
//Splay(pre),splay(suc,pre)
//Then delete the left son of the suc
int kth(int k){int u=root,sn=0;for(;;){sn=ch[u][0];if(k>siz[sn]+cnt[u]) k-=siz[sn]+cnt[u],u=ch[u][1];else if(siz[sn]>=k) u=sn;else return v[u];}
}
int main()
{insert(2147483647);insert(-2147483647);scanf("%d",&t);while(t--){int op,x;scanf("%d%d",&op,&x);if(op==1) insert(x);else if(op==2) dele(x);else if(op==3) find(x),printf("%d\n",siz[ch[root][0]]);else if(op==4) printf("%d\n",kth(x+1));else if(op==5) printf("%d\n",v[nxt(x,0)]);else printf("%d\n",v[nxt(x,1)]);}return 0;
}