fail树是由KMP算法引申出来的概念,在解释fail树之前,首先要讲一个概念:Border。
接下来就可以讲fail树了,这里先不讲引入fail树的原因,先说fail树是个啥。
根据前面所说,我们就能推出以下结论,S的两个前缀S[1,p]和S[1,q]的公共最长Border长度,就是lca(ne[p], ne[q])。而这道模板题就是要用到这个基础结论:【模板】失配树 - 洛谷
代码如下:
- #include
- using namespace std;
- #define FOR(i, a, b) for (int i = (a); i <= (b); i++)
- // #define int long long
- #define pii pair
- const int N = 1e6+5, mod=1e9+7;
- char s[N]; int n,m;
- int ne[21][N], d[N];
- int lg[N];
-
- int lca(int x,int y){
- if(d[x] < d[y]) swap(x,y);
- while(d[x] > d[y]) x = ne[lg[d[x]-d[y]]][x];
- if(x==y) return y;
- for(int k=lg[d[x]]; k>=0; k--){
- if(ne[k][x] != ne[k][y]){x=ne[k][x]; y=ne[k][y];}
- }
- return ne[0][x];
- }
- void solve(){
- cin>>(s+1); n=strlen(s+1);
- //init of lg[]
- FOR(i,2,n) lg[i]=lg[i>>1]+1;
- //get_ne
- for(int i=2,j=0; i<=n; i++){
- while(j && s[i]!=s[j+1]) j=ne[0][j];
- if(s[i]==s[j+1]) j++;
- ne[0][i] = j, d[i]=d[j]+1; //记录next和深度d
- }
- //预处理倍增跳
- FOR(j,1,20) FOR(i,1,n)
- ne[j][i] = ne[j-1][ne[j-1][i]];
- //处理询问
- cin>>m;
- FOR(i,1,m){
- int x,y; cin>>x>>y;
- cout<<lca(ne[0][x], ne[0][y])<<'\n';
- }
- }
- signed main(){
- ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
- int T=1;
- while(T--) solve();
- }
再补充一道fail树的应用题:[NOI2014] 动物园 - 洛谷
大致题意是要求字符串S所有前缀 S[1,x] (1≤x≤n)的长度不大于x/2的Border数量。
我们这样考虑问题:如果没有长度不大于x/2的要求,那就很简单,求一个点的祖先数量就行,很容易预处理。而加上这个条件之后,根据fail树数字大小的单调性(ne[x]<x),我们也能知道,符合条件的祖先是“上面的连续一段”。所以我们依然可以直接预处理每个点的祖先数量(其实就是深度),然后不断跳fail,直到找到第一个长度不大于x/2的,它的祖先数量就是当前的答案。
这个思路没错,但是复杂度不ok,因为暴力跳fail是O(n)的,再算上n次询问,总复杂度O(n^2),过不了。所以把暴力跳fail改成倍增跳,优化成O(nlogn)就能过了。
代码如下:
- #include
- using namespace std;
- #define FOR(i, a, b) for (int i = (a); i <= (b); i++)
- // #define int long long
- #define pii pair
- const int N = 1e6+5, mod=1e9+7;
- char s[N]; int n;
- int ne[21][N], num[N];
-
- void solve(){
- //init
- memset(num,0,sizeof(num));
- num[1] = 1;
- //input
- cin>>(s+1); n=strlen(s+1);
- //get_ne
- for(int i=2,j=0; i<=n; i++){
- while(j && s[i]!=s[j+1]) j=ne[0][j];
- if(s[i]==s[j+1]) j++;
- ne[0][i] = j;
- num[i] = num[j]+1;
- }
- //预处理倍增跳
- FOR(j,1,20) FOR(i,1,n)
- ne[j][i] = ne[j-1][ne[j-1][i]];
- //跳fail到合适位置,取出答案
- long long ans = 1;
- FOR(i,1,n){
- int tt = ne[0][i];
- for(int j=20; j>=0; j--)
- if((ne[j][tt]<<1) > i) tt=ne[j][tt];
- // if((ne[tt][j]<<1) > i) tt=ne[tt][j];
- if((tt<<1) > i) tt = ne[0][tt];
- ans = (ans*(num[tt]+1))%mod;
- }
- cout<
'\n'; - }
- signed main(){
- ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
- int T=1; cin>>T;
- while(T--) solve();
- }