bzoj4566 找相同字符 [后缀自动机]

in 题目 with 0 comment

Description

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。


Input

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母


Output

输出一个整数表示答案


Sample Input

aabb
bbaa


Sample Output

10


Solution

对第一个串建一个后缀自动机,第二个串在自动机上跑即可.特别简单的思路,唯一需要注意的是统计跑到某一位时怎样把之前所有可能匹配到的子串一并加起来,我们可以建一个total_pos[ ],记录的是匹配到当前节点所有可能的(与这个点相关的)子串匹配数,这样我们在work( )的过程中,每次将当前匹配成功的节点p的total_pos[fa[p]]以及v[p]*(len-maxl[fa[p]])(二式表示不完整的那部分又可以对应几组可匹配子串)的和加起来即可


Code

#include<bits/stdc++.h>
#define maxn 200010
#define maxt 400010
// #define DEBUG
using namespace std;
int n;
int sum[maxt],tmp[maxt];
long long total_pos[maxt];
char ch[maxn];

struct SAM
{
  int last,root,tot;
  int son[maxt][26],fa[maxt],maxl[maxt],r[maxt];

  void init() { last=root=tot=1; }
  int addnode(int x) { return maxl[++tot]=x,tot; }

  void add(int pos)
  {
    int x=ch[pos]-'a',np=addnode(pos),p=last;
    last=np,r[np]=1;
    for( ; p&&!son[p][x] ; p=fa[p] ) son[p][x]=np;
    if(!p) fa[np]=root;
    else
    {
      int q=son[p][x];
      if(maxl[q]==maxl[p]+1) fa[np]=q;
      else
      {
        int nq=addnode(maxl[p]+1);
        memcpy(son[nq],son[q],sizeof(son[q]));
        fa[nq]=fa[q];
        fa[q]=fa[np]=nq;
        for( ; son[p][x]==q ; p=fa[p] ) son[p][x]=nq;
      }
    }
  }

  void Tsort()
  {
    for(int i=1;i<=tot;++i) sum[maxl[i]]++;
    for(int i=1;i<=n;++i) sum[i]+=sum[i-1];
    for(int i=1;i<=tot;++i) tmp[sum[maxl[i]]--]=i;
#ifdef DEBUG
    for(int i=1;i<=tot;++i) printf("tmp[%d]=%d  maxl[tmp[%d]]=%d\n",i,tmp[i],i,maxl[tmp[i]]);
#endif
    for(int i=tot;i;i--) r[fa[tmp[i]]]+=r[tmp[i]];
    for(int i=1;i<=tot;++i)
      total_pos[tmp[i]]=total_pos[fa[tmp[i]]]+r[tmp[i]]*(maxl[tmp[i]]-maxl[fa[tmp[i]]]);
  }

  void build()
  {
    init();
    scanf("%s",ch+1),n=strlen(ch+1);
    for(int i=1;i<=n;++i) add(i);
#ifdef DEBUG
    cout<<"n="<<n<<endl;
#endif
    Tsort();
  }

  void work()
  {
    scanf("%s",ch+1),n=strlen(ch+1);
    int p=root,len=0;
    long long ans=0;
    for(int i=1;i<=n;++i)
    {
      int x=ch[i]-'a';
      if(son[p][x]) len++,p=son[p][x];
      else
      {
        for( ; p&&!son[p][x] ; p=fa[p] ) ;
        if(!p) len=0,p=root;
        else len=maxl[p]+1,p=son[p][x];
      }
      if(p!=root) ans+=total_pos[fa[p]]+r[p]*(len-maxl[fa[p]]);
    }
    printf("%lld\n",ans);
  }
} sam ;

int main()
{
#ifdef DEBUG
  freopen("in.txt","r",stdin);
  freopen("out.txt","w",stdout);
#endif
  sam.build();
  sam.work();
  return 0;
}


扫描二维码,在手机上阅读!
Responses