快速傅里叶变换(FFT)

in 知识点 with 0 comment

例题:UOJ#34. 多项式乘法

Description

给你两个多项式,请输出乘起来后的多项式。

Input

第一行两个整数 n 和 m,分别表示两个多项式的次数。
第二行 n+1 个整数,分别表示第一个多项式的 0 到 n 次项前的系数。
第三行 m+1 个整数,分别表示第一个多项式的 0 到 m 次项前的系数。

Output

一行 n+m+1 个整数,分别表示乘起来后的多项式的 0 到 n+m 次项前的系数。

Sample Input

1 2
1 2
1 2 1

Sample Output

1 4 5 2

explanation

$(1+2x)⋅(1+2x+x^2)=1+4x+5x^2+2x^3$

Hint

0≤n , m≤$10^5$

Solution

首先这道涉及多项式乘法,求相乘后新多项式的系数,可简记为:
$$C(x)=A(x)B(x)=\sum_{j=0}^{2n-2} (\sum_{k=0}^ja_k*b_{k-j})x^j$$
正常算法的时间复杂度很容易得到,即O($n^2$),显然无法再规定时间内得到正确解,怎么办呢?这里我们介绍一种可在O($nlgn$)复杂度下快速求出该答案的方法,即FFT.希望读者在阅读此文后可自行编写!

多项式

多项式的表示

单位复数根

定义

满足$w^n=1$的复数$w$,其单位复数根恰好有n个,分别为$e^{2\pi ik/n},k={0,1,2,\dots,n-1}$,由复数的指数形式定义$e^{iu}=cos(u)+isin(u)$可将其转化为 $y_k=cos(2 \pi k/n)+isin(2\pi k/n)$

基本性质

DFT离散傅里叶变换

这个算法的核心是利用了卷积定理$$ a\times b=DFT^{-1}{2n}(DFT{2n}(a)\cdot DFT_{2n}(b)) $$

本文最开始的例题UOJ#34,目标多项式的系数$c_k=\sum_{k=0}^ja_k*b_{k-j}$,熟悉的人可能都知道这实际上就是a,b的卷积,能用傅里叶变换求解的题目一般都可以被转化成类似这样的卷积的形式,大家一定要对这个式子足够熟悉!!!

$$ y_k=A(W_n^k)=\sum_{j=0}^{n-1}a_j\cdot W_n^{kj}=\sum_{j=0}^{n-1}a_j\cdot e^{\frac{2\pi i}{n}jk} $$
该算法的复杂度是O($n^2$)的,有没有适当变换使其结合一些复数根的性质加速此过程?答案是肯定的!

FFT快速傅里叶变换

递归

利用分治的思想将$A(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$分为下标为奇数和偶数的两部分:
$$ A^{[

FFT(a):  
    n=a.length()  
    if n==1:  
        return a  
    w_n=e^(pi*i/n)=complex(cos(2*pi/n),sin(2*pi/n))  
    w=1  
    a(0)=[a0,a2,...a_n-2]  
    a(1)=[a1,a3,...a_n-1]  
    y(0)=FFT(a(0))  
    y(1)=FFT(a(1))  
    for k in range(0,n/2):  
        y_k=y_k(0)+w*y_k(1)                     //w*y_k(1)为公用子表达式 
        y_k+n/2=y_k(0)-w*y_k(1)  
        w=w*w_n                                 //w为旋转因子
    return y  

但递归的常数是很大的,我们是否可以进一步优化常数呢?只要将递归过程改为迭代的过程就好了!

迭代

inline int rev(int x,int n)                 //x为当前处理的待改变的数,n为二进制位的总长度(按上例则n=3)
{
    int x0=0;
    while(n--) x0=(x0+(x&1))<<1,x>>=1;
    return x0>>1;
}

因此只要知道出$y^{[

for k in range(0,n/2):  
    t=w*y_k(1)  
    y_k=y_k(0)+t  
    y_k+n/2=y_k(0)-t  
    w=w*w_n 

傅里叶逆变换公式

以上我们了解到如何将系数表示转换为点值表示,通过点值表示在O(n)复杂度下求出多项式的乘积之后只要再将点值表示转换为系数表示(求插值)即可.前面讲多项式的点值表达时我们提到了一种求插值的过程,$a=V(x_0,x_1,x_2,\dots,x_{n-1})^{-1}\cdot y$ , 即只要得到范德蒙德行列式的逆矩阵就能求出对应的a.

由于一个矩阵的逆矩阵$A^{-1}=\frac{1}{|A|}A^*$,易推得傅里叶逆变换公式:
$$ a_k=\frac{1}{n}\sum_{j=0}^{n-1}y^j\cdot e^{-\frac{2\pi i}{n}jk} $$
除了这种求逆矩阵的方法,我们还可以用拉格朗日公式求插值,但复杂度为O($n^2$),公式如下:
$$ A(x)=\sum_{k=0}^{n-1}y_k\frac{ \prod_{j\neq k}(x-x_j) }{ \prod_{j\neq k}(x_j-x_k) } $$

完整代码

大家最想要的代码来了,UOJ#34 AC代码:

#include<bits/stdc++.h>
#define pi acos(-1.0)
#define maxn 300010
//#define DEBUG                                     //DEBUG无视就好
using namespace std;
int n,m;
complex<double> a[maxn],b[maxn];

inline int read()                                   //读入优化
{
    char ch;
    int read=0;
    int sign=1;
    do
        ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-');
    if(ch=='-') sign=-1,ch=getchar();
    while(ch>='0'&&ch<='9')
    {
        read=read*10+ch-'0';
        ch=getchar(); 
    } 
    return read*sign;
}

int Power2(int x)                                                //把x转化为2的整数次幂
{
    int x0;
    for(x0=1;x0<=x;x0<<=1) ;
    return x0;
}

inline int lg(int n)                                             //计算二进制位数
{
    int l=0;
    if(n==0) return l;
    for(int x=1;x<=n;x<<=1) l++;
    return l;
}

inline int rev(int x,int n)                                       //位逆序置换
{
    int x0=0;
    while(n--) x0=(x0+(x&1))<<1,x>>=1;
    return x0>>1;
}

void FFT(complex<double> a[],int n,int flag)    //主体
{
    complex<double> A[n+1];
    for(int i=0,l=lg(n-1);i<n;++i) A[rev(i,l)]=a[i];
    #ifdef DEBUG
    int l=lg(n-1);                                               //切记是lg(n-1)
    cerr<<"l="<<l<<endl;
    for(int i=0;i<n;++i) cerr<<rev(i,l)<<" ";
    cerr<<endl;
    #endif 
    for(int i=2;i<=n;i<<=1)                                     //枚举合并后序列长度
    {
        complex<double> dw(cos(2*pi/i),sin(flag*2*pi/i));
        for(int j=0;j<n;j+=i)                                   //该长度下每部分进行求解
        {
            complex<double> w(1.0,0);
            for(int k=0;k<(i>>1);k++,w=w*dw)                    //蝴蝶变换,只需求i>>1次即可
            {
                complex<double> u=A[j+k];
                complex<double> t=w*A[j+k+(i>>1)];
                A[j+k]=u+t;
                A[j+k+(i>>1)]=u-t;
            }
        }
        if(flag==-1)
            for(int i=0;i<n;++i) a[i]=int(A[i].real()/n+0.5);
        else
            for(int i=0;i<n;++i) a[i]=A[i];
    }
}

int main()
{
    #ifdef DEBUG
    freopen("in.txt","r",stdin);
    #endif
    n=read();
    m=read();
    for(int i=0;i<=n;++i) a[i]=read();
    for(int i=0;i<=m;++i) b[i]=read();
    int length=Power2(n+m);
    #ifdef DEBUG
    cerr<<"length="<<length<<endl;
    #endif
    FFT(a,length,1);
    FFT(b,length,1);
    for(int i=0;i<=length;++i) a[i]*=b[i];
    FFT(a,length,-1);
    for(int i=0;i<=n+m;++i) printf("%d ",int(a[i].real()));
    return 0;
}
/*FFT高精度*/

include<bits/stdc++.h>
#define PI acos(-1.0)
#define eps 1e-1
#define maxn 200005
#define DEBUG
using namespace std;
int n,m,l=0;
int rev[maxn],ans[maxn];
char x[maxn],y[maxn];

struct Complex
{
  double real,imag;
  Complex(double real=0,double imag=0):real(real),imag(imag) {}
  Complex operator + (const Complex rhs)
  {
    return Complex(real+rhs.real,imag+rhs.imag);
  }
  Complex operator - (const Complex rhs)
  {
    return Complex(real-rhs.real,imag-rhs.imag);
  }
  Complex operator * (const Complex rhs)
  {
     return Complex((real*rhs.real-imag*rhs.imag),(real*rhs.imag+imag*rhs.real));
  }
};
Complex a[maxn],b[maxn];

inline int read()
{
  char ch;
  int read=0,sign=1;
  do
    ch=getchar();
  while((ch<'0'||ch>'9')&&ch!='-');
  if(ch=='-') sign=-1,ch=getchar();
  while(ch>='0'&&ch<='9')
  {
    read=read*10+ch-'0';
    ch=getchar();
  }
  return sign*read;
}

void pre_work()
{
  int length1,length2;
  scanf("%s",x);length1=strlen(x);
  scanf("%s",y);length2=strlen(y);
  n=max(length1,length2);
  for(int i=0;i<length1;++i) a[i].real=x[length1-i-1]-'0';
  for(int i=0;i<length2;++i) b[i].real=y[length2-i-1]-'0';
#ifdef DEBUG
  for(int i=0;i<n;++i) cerr<<a[i].real<<" ";
  cerr<<endl;
  for(int i=0;i<n;++i) cout<<b[i].real<<" ";
  cerr<<endl;
#endif
  m=2*n;
  for(n=1;n<m;n<<=1) l++;
  for(int i=0;i<n;++i) rev[i]=rev[i>>1]>>1|(i&1)<<(l-1);
#ifdef DEBUG
  for(int i=0;i<n;++i) cerr<<i<<"-->"<<rev[i]<<endl;
#endif
}

void FFT(Complex a[],int n,int sign)
{
  for(int i=0;i<n;++i)
    if(rev[i]<i) swap(a[i],a[rev[i]]);
  for(int i=2;i<=n;i<<=1)
  {
    Complex dw(cos(2*PI/i),sin(2*PI*sign/i));
    for(int j=0;j<n;j+=i)
    {
      Complex w(1,0);
      for(int k=0;k<(i>>1);k++,w=dw*w)
      {
        Complex u=a[j+k];
        Complex t=a[j+k+(i>>1)]*w;
        a[j+k]=u+t;
        a[j+k+(i>>1)]=u-t;
      }
    }
  }
  if(sign==-1)
    for(int i=0;i<n;++i) ans[i]=int(a[i].real/n+eps);
}

void push_ans()
{
  for(int i=0;i<n;++i)
    if(ans[i]>=10) ans[i+1]+=ans[i]/10,ans[i]%=10;
  int first=n-1;
  while(ans[first]==0) first--;
  for(int i=first;i>-1;i--) printf("%d",ans[i]);
}

int main()
{
  pre_work();
  FFT(a,n,1);
  FFT(b,n,1);
  for(int i=0;i<n;++i) a[i]=a[i]*b[i];
  FFT(a,n,-1);
#ifdef DEBUG
  for(int i=0;i<n;++i) cerr<<ans[i]<<" ";
  cerr<<endl;
#endif
  push_ans();
  return 0;
}

参考资料

http://blog.csdn.net/oiljt12138/article/details/54810204
http://blog.csdn.net/iamzky/article/details/22712347
算法导论第三十章


扫描二维码,在手机上阅读!
Responses