Polynomial algorithms: FFT, NTT and FWT

Recommended reading:

FFT: https://www.cnblogs.com/pks-t/p/9251147.html
Fast Fourier transform (FFT), number-theoretic transform (NTT) and fast Walsh–Hadamard transform (FWT): https://www.cnblogs.com/NaVi-Awson/p/8684389.html

Contents:

FFT
NTT
Arbitrary-modulus NTT
High-precision (big-integer) multiplication in O(n log n)
FWT

Templates:

Honestly, the theory is hard to fully understand; in practice you can copy the templates below and adapt them.
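Tip 1: pay close attention to array lengths — they are a very common source of bugs! The transform length `limit` is the smallest power of two greater than n + m, which can be up to about twice n + m, so the arrays must have room for it (the templates below allocate `maxn << 2`).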
Tip 2: when two polynomials are multiplied, each pair of terms multiplies its coefficients and adds its exponents, $a x^i \cdot b x^j = ab\,x^{i+j}$. Recognising this convolution structure is often the key to applying FFT/NTT to a problem.

Polynomial multiplication:

FFT:

#include<bits/stdc++.h>
using namespace std;        //FFT template

// Capacity bound on polynomial degree; the buffers below are sized maxn << 2.
constexpr int maxn = 1000005;

/// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct Complex
{
    double x, y;
    Complex(double dx = 0, double dy = 0) : x(dx), y(dy) {}
};

/// Component-wise sum.
Complex operator+(const Complex& a, const Complex& b)
{
    const double re = a.x + b.x;
    const double im = a.y + b.y;
    return Complex(re, im);
}

/// Component-wise difference.
Complex operator-(const Complex& a, const Complex& b)
{
    const double re = a.x - b.x;
    const double im = a.y - b.y;
    return Complex(re, im);
}

/// Complex product: (a.x + i a.y)(b.x + i b.y).
Complex operator*(const Complex& a, const Complex& b)
{
    const double re = a.x * b.x - a.y * b.y;
    const double im = a.x * b.y + a.y * b.x;
    return Complex(re, im);
}

const double pi = acos(-1.0);  // pi, computed once at start-up

// Transform state shared by work() and FFT():
int limit;  // transform length, a power of two
int bit;    // log2(limit)
int len;    // degree of the product, n + m
int n;      // degree of the first input polynomial
int m;      // degree of the second input polynomial

int wz[maxn << 2];  // bit-reversal permutation for the current limit

// Coefficient buffers; a also receives the product.
Complex a[maxn << 2];
Complex b[maxn << 2];

/// In-place iterative radix-2 FFT over A[0, limit).
/// inv = 1 gives the forward transform; inv = -1 gives the inverse transform
/// WITHOUT the 1/limit scaling (work() divides afterwards).
/// Reads the globals limit, wz (bit-reversal table) and pi.
void FFT(Complex *A, int inv)
{
    // Put elements in bit-reversed order so the butterflies can run in place.
    for (int pos = 0; pos < limit; ++pos)
    {
        const int partner = wz[pos];
        if (pos < partner)
            std::swap(A[pos], A[partner]);
    }
    // half = length of the sub-transforms already merged.
    for (int half = 1; half < limit; half <<= 1)
    {
        const Complex step(std::cos(pi / half), inv * std::sin(pi / half));
        const int span = half << 1;
        for (int start = 0; start < limit; start += span)
        {
            Complex *lo = A + start;
            Complex *hi = lo + half;
            Complex w(1, 0);
            for (int k = 0; k < half; ++k)
            {
                const Complex even = lo[k];
                const Complex odd = w * hi[k];
                lo[k] = even + odd;
                hi[k] = even - odd;
                w = w * step;
            }
        }
    }
}

/// Multiplies the global polynomials a (degree n) and b (degree m) via FFT.
/// On return a[i].x holds coefficient i of the product, rounded to the
/// nearest integer, for 0 <= i <= len (len = n + m). b is overwritten.
void work()
{
    limit = 1, bit = 0, len = n + m;
    // Smallest power of two strictly greater than len, so all len + 1 product
    // coefficients fit without cyclic wrap-around.
    while (limit <= len)
    {
        limit <<= 1;
        bit++;
    }
    // Bit-reversal table. limit >> 1 equals 1 << (bit - 1) for bit >= 1 and
    // avoids the undefined shift by -1 when limit == 1 (n == m == 0).
    for (int i = 0; i < limit; i++)
        wz[i] = (wz[i >> 1] >> 1) | ((i & 1) ? (limit >> 1) : 0);
    FFT(a, 1);
    FFT(b, 1);
    for (int i = 0; i < limit; i++)
        a[i] = a[i] * b[i];
    FFT(a, -1);
    // Scale by 1/limit and round to nearest. floor(v + 0.5) is also correct
    // for negative coefficients; "+0.5 then (int) truncation" turned e.g. a
    // computed -3.0000001 into -2.
    for (int i = 0; i <= len; i++)
        a[i].x = std::floor(a[i].x / limit + 0.5);
}

int main()
{
    while(~scanf("%d%d",&n,&m)) //The length of two polynomials
    {
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));
        int v;
        for(int i=0;i<=n;i++) //From low to high
        {
            scanf("%d",&v);
            a[i].x=v;
        }
        for(int i=0;i<=m;i++) //From low to high
        {
            scanf("%d",&v);
            b[i].x=v;
        }
        work();
        printf("%d",(int)(a[0].x));
        for(int i=1;i<=len;i++) //From low to high
            printf(" %d",(int)(a[i].x));
        putchar('\n');
    }
    return 0;
}

NTT:

#include<bits/stdc++.h>
using namespace std;        //NTT template
using ll = long long;

constexpr int maxn = 1000005;   // capacity bound; buffers are sized maxn << 2
constexpr int MOD = 998244353;  // modulus (119 * 2^23 + 1)
constexpr int G = 3;            // primitive root of MOD

// Transform state shared by work() and NTT():
int limit;  // transform length, a power of two
int bit;    // log2(limit)
int len;    // degree of the product, n + m
int n;      // degree of the first input polynomial
int m;      // degree of the second input polynomial

int wz[maxn << 2];  // bit-reversal permutation for the current limit

// Coefficient buffers; a also receives the product.
ll a[maxn << 2];
ll b[maxn << 2];

/// Returns x^y mod MOD by binary exponentiation.
/// Products are plain long long, so x must be small enough that x*x does not
/// overflow (true for the roots and lengths passed in this program).
inline ll qpow(ll x, ll y)
{
    ll result = 1;
    ll base = x;
    while (y != 0)
    {
        if (y & 1)
            result = base * result % MOD;
        base = base * base % MOD;
        y >>= 1;
    }
    return result;
}

/// In-place iterative number-theoretic transform of A[0, limit) modulo MOD.
/// inv = 1 gives the forward transform; inv = -1 uses inverse roots and does
/// NOT divide by limit (work() does that). Reads the globals limit and wz.
void NTT(ll *A, int inv)
{
    // Bit-reversed reordering so the butterflies can run in place.
    for (int pos = 0; pos < limit; ++pos)
        if (pos < wz[pos])
            std::swap(A[pos], A[wz[pos]]);
    for (int half = 1; half < limit; half <<= 1)
    {
        const int span = half << 1;
        // span-th root of unity modulo MOD (inverted for the inverse transform).
        ll root = qpow(G, (MOD - 1) / span);
        if (inv == -1)
            root = qpow(root, MOD - 2);
        for (int start = 0; start < limit; start += span)
        {
            ll w = 1;
            for (int k = 0; k < half; ++k)
            {
                ll &lo = A[start + k];
                ll &hi = A[start + k + half];
                const ll u = lo;
                const ll v = w * hi % MOD;
                lo = (u + v) % MOD;
                hi = (u - v + MOD) % MOD;
                w = w * root % MOD;
            }
        }
    }
}

/// Multiplies the global polynomials a (degree n) and b (degree m) modulo MOD.
/// On return a[0..len] (len = n + m) holds the product coefficients in
/// [0, MOD); b is overwritten with its transform.
void work()
{
    limit = 1, bit = 0, len = n + m;
    // Smallest power of two strictly greater than len.
    while (limit <= len)
    {
        limit <<= 1;
        bit++;
    }
    // Bit-reversal table. limit >> 1 equals 1 << (bit - 1) for bit >= 1 and
    // avoids the undefined shift by -1 when limit == 1 (n == m == 0).
    for (int i = 0; i < limit; i++)
        wz[i] = (wz[i >> 1] >> 1) | ((i & 1) ? (limit >> 1) : 0);
    // Normalise inputs into [0, MOD). The butterflies assume non-negative
    // residues; negative input gives negative (wrong) results, and huge input
    // overflows the g * A products.
    for (int i = 0; i < limit; i++)
    {
        a[i] = (a[i] % MOD + MOD) % MOD;
        b[i] = (b[i] % MOD + MOD) % MOD;
    }
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < limit; i++)
        a[i] = a[i] * b[i] % MOD;
    NTT(a, -1);
    ll inv = qpow(limit, MOD - 2); // modular inverse of the transform length
    for (int i = 0; i <= len; i++)
        a[i] = a[i] * inv % MOD;
}

int main()
{
    scanf("%d %d",&n,&m);
    for(int i=0;i<=n;i++) //From low to high
        scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++) //From low to high
        scanf("%lld",&b[i]);
    work();
    printf("%lld",a[0]);
    for(int i=1;i<=len;i++)
        printf(" %lld",a[i]);
    return 0;
}

Arbitrary-modulus NTT (three NTT primes combined with the Chinese remainder theorem):

#include<bits/stdc++.h>
using namespace std;        //FFT template
using ll = long long;

constexpr int maxn = 100005;  // capacity bound; buffers are sized maxn << 2

int idx = 0;   // index (0..2) of the NTT prime currently in use
int n;         // degree of the first input polynomial
int m;         // degree of the second input polynomial
int realmod;   // modulus requested for the output
int limit;     // transform length, a power of two
int bit;       // log2(limit)
int len;       // degree of the product, n + m
int wz[maxn << 2];  // bit-reversal permutation for the current limit

// Three NTT-friendly primes, each with primitive root 3. Their product
// (about 1.7e26) must exceed the exact product coefficients for CRT to
// recover them.
ll MOD[3] = {167772161, 998244353, 1004535809};
int G[3] = {3, 3, 3};

ll a[maxn << 2], b[maxn << 2];        // per-prime working copies of the inputs
ll c[3][maxn << 2];                   // product coefficients modulo each prime
ll beg1[maxn << 2], beg2[maxn << 2];  // original input coefficients

/// Returns (a * b) mod p in [0, p) for 0 < p < 2^62, without 128-bit integers.
/// The quotient floor(a*b/p) is estimated in long double and may be off by a
/// small amount; the exact remainder is then recovered in unsigned 64-bit
/// arithmetic, where wrap-around is well defined. The original did this
/// subtraction in signed long long, whose overflow is undefined behavior.
/// Negative a or b are reduced into [0, p) first.
/// (long long is the same type as the file's `ll` alias.)
inline long long fmul(long long a, long long b, long long p)
{
    a %= p;
    if (a < 0)
        a += p;
    b %= p;
    if (b < 0)
        b += p;
    const unsigned long long q =
        static_cast<unsigned long long>(static_cast<long double>(a) * b / p);
    // a*b - q*p computed modulo 2^64; because q is close to the true
    // quotient, the real value is a small multiple of p away from [0, p)
    // and fits in a signed long long.
    long long r = static_cast<long long>(
        static_cast<unsigned long long>(a) * static_cast<unsigned long long>(b) -
        q * static_cast<unsigned long long>(p));
    r %= p;
    if (r < 0)
        r += p;
    return r;
}

/// Returns x^y mod p by binary exponentiation (long long is the file's `ll`).
/// Products are plain long long, so the base must be small enough that its
/// square does not overflow.
inline long long qpow(long long x, long long y, long long p)
{
    long long acc = 1;
    long long base = x;
    while (y != 0)
    {
        if (y & 1)
            acc = base * acc % p;
        base = base * base % p;
        y >>= 1;
    }
    return acc;
}

inline void CRT()
{
    ll m1=MOD[0],m2=MOD[1],m3=MOD[2];
    for(int i=0;i<=len;i++)
    {
        ll M=m1*m2;
        ll ans=fmul(c[0][i]*m2,qpow(m2,m1-2,m1),M)+fmul(c[1][i]*m1,qpow(m1,m2-2,m2),M);
        ans%=M;
        ll t1=((c[2][i]-ans)%m3+m3)%m3;
        ll t2=t1*qpow(M%m3,m3-2,m3)%m3;
        ans%=realmod,M%=realmod,t2%=realmod;
        c[0][i]=(M*t2%realmod+ans)%realmod;
    }
}

/// In-place iterative number-theoretic transform of A[0, limit) modulo
/// MOD[idx], the prime selected by the global idx.
/// inv = 1 gives the forward transform; inv = -1 uses inverse roots and does
/// NOT divide by limit (work() does that). Reads the globals limit and wz.
void NTT(ll *A, int inv)
{
    const ll p = MOD[idx];
    // Bit-reversed reordering so the butterflies can run in place.
    for (int pos = 0; pos < limit; ++pos)
        if (pos < wz[pos])
            std::swap(A[pos], A[wz[pos]]);
    for (int half = 1; half < limit; half <<= 1)
    {
        const int span = half << 1;
        // span-th root of unity modulo p (inverted for the inverse transform).
        ll root = qpow(G[idx], (p - 1) / span, p);
        if (inv == -1)
            root = qpow(root, p - 2, p);
        for (int start = 0; start < limit; start += span)
        {
            ll w = 1;
            for (int k = 0; k < half; ++k)
            {
                const ll u = A[start + k];
                const ll v = w * A[start + k + half] % p;
                A[start + k] = (u + v) % p;
                A[start + k + half] = (u - v + p) % p;
                w = w * root % p;
            }
        }
    }
}

void work()
{
    limit=1,bit=0,len=n+m;
    while(limit<=len)
    {
        limit<<=1;
        bit++;
    }
    for(int i=0;i<limit;i++)
        wz[i]=(wz[i>>1]>>1)|((i&1)<<(bit-1));
    while(idx<3)
    {
        memcpy(a,beg1,sizeof(ll)*limit);
        memcpy(b,beg2,sizeof(ll)*limit);
        NTT(a,1);
        NTT(b,1);
        for(int i=0;i<limit;i++)
            a[i]=a[i]*b[i]%MOD[idx];
        NTT(a,-1);
        ll inv=qpow(limit,MOD[idx]-2,MOD[idx]); //Inverse element of length
        for(int i=0;i<=len;i++)
            c[idx][i]=a[i]*inv%MOD[idx];
        ++idx;
    }
    CRT();
}

int main()
{
    scanf("%d %d %d",&n,&m,&realmod);
    for(int i=0;i<=n;i++) //From low to high
        scanf("%lld",&beg1[i]);
    for(int i=0;i<=m;i++) //From low to high
        scanf("%lld",&beg2[i]);
    work();
    printf("%lld",c[0][0]);
    for(int i=1;i<=len;i++)
        printf(" %lld",c[0][i]);
    return 0;
}

High-precision (big-integer) multiplication:

#include<bits/stdc++.h>
using namespace std;        //FFT template
                            //High precision multiplication with different digits
// Maximum digit count (plus slack); buffers below are sized from it.
constexpr int maxn = 1000005;

/// Complex number used by the FFT: x is the real part, y the imaginary part.
struct Complex
{
    double x, y;
    Complex(double dx = 0, double dy = 0) : x(dx), y(dy) {}
};

/// Component-wise sum.
Complex operator+(const Complex& a, const Complex& b)
{
    return Complex(a.x + b.x, a.y + b.y);
}

/// Component-wise difference.
Complex operator-(const Complex& a, const Complex& b)
{
    return Complex(a.x - b.x, a.y - b.y);
}

/// Complex product: (a.x + i a.y)(b.x + i b.y).
Complex operator*(const Complex& a, const Complex& b)
{
    const double real_part = a.x * b.x - a.y * b.y;
    const double imag_part = a.x * b.y + a.y * b.x;
    return Complex(real_part, imag_part);
}

const double pi = acos(-1.0);  // pi, computed once at start-up

// Transform state shared by work(), FFT() and main():
int limit;  // transform length (limit == 1 << bit); main() reuses it as a digit cursor
int bit;    // log2(limit)
int len;    // n + m
int n;      // digit count of the first operand
int m;      // digit count of the second operand

int wz[maxn << 2];  // bit-reversal permutation for the current limit
int re[maxn << 2];  // result digits, least significant first

// Digit polynomials; a also receives the product.
Complex a[maxn << 2];
Complex b[maxn << 2];

// The two operands as decimal strings.
char s1[maxn];
char s2[maxn];

/// In-place iterative radix-2 FFT over A[0, limit).
/// inv = 1 gives the forward transform; inv = -1 gives the inverse transform
/// WITHOUT the 1/limit scaling (work() divides afterwards).
/// Reads the globals limit, wz (bit-reversal table) and pi.
void FFT(Complex *A, int inv)
{
    // Bit-reversal permutation; each pair is swapped only once.
    for (int i = 0; i < limit; ++i)
    {
        if (i < wz[i])
            std::swap(A[i], A[wz[i]]);
    }
    for (int half = 1; half < limit; half <<= 1)
    {
        const double angle = pi / half;
        const Complex wn(std::cos(angle), inv * std::sin(angle));
        for (int base = 0; base < limit; base += half << 1)
        {
            Complex w(1, 0);
            for (int k = base; k < base + half; ++k)
            {
                const Complex u = A[k];
                const Complex v = w * A[k + half];
                A[k] = u + v;
                A[k + half] = u - v;
                w = w * wn;
            }
        }
    }
}

/// Multiplies the digit polynomials a (n digits) and b (m digits) via FFT.
/// Afterwards, for every i < limit, a[i].x is coefficient i of the product
/// plus 0.5, so a plain (int) conversion rounds it to the nearest integer
/// (coefficients are non-negative here). b is overwritten.
void work()
{
    len = n + m;
    // Smallest power of two strictly greater than len.
    for (limit = 1, bit = 0; limit <= len; ++bit)
        limit <<= 1;
    // Bit-reversal table; top == 1 << (bit - 1) whenever bit >= 1.
    const int top = limit >> 1;
    for (int i = 0; i < limit; ++i)
        wz[i] = (wz[i >> 1] >> 1) | ((i & 1) ? top : 0);
    FFT(a, 1);
    FFT(b, 1);
    for (int i = 0; i < limit; ++i)
        a[i] = a[i] * b[i];
    FFT(a, -1);
    for (int i = 0; i < limit; ++i)
        a[i].x = a[i].x / limit + 0.5;
}

int main()
{
    while(~scanf("%s%s",s1,s2))
    {
        memset(a,0,sizeof(a));
        memset(b,0,sizeof(b));          //An n-digit decimal number can be regarded as an n-1 degree polynomial
        n=strlen(s1);
        m=strlen(s2);
        for(int i=n-1,j=0;i>=0;i--,j++)
        {
            a[j].x=s1[i]-48;
            a[j].y=0;
        }
        for(int i=m-1,j=0;i>=0;i--,j++)
        {
            b[j].x=s2[i]-48;
            b[j].y=0;
        }
        work();
        memset(re,0,sizeof(re));
        for(int i=0;i<=limit;i++)
        {
            re[i]+=(int)(a[i].x);
            if(re[i]>=10) //carry
            {
                re[i+1]+=re[i]/10;
                re[i]%=10;
                if(i==limit)
                    ++limit;
            }
        }
        while(!re[limit]&&limit>=1)//Remove high 0
            limit--;
        while(limit>=0)
            printf("%d",re[limit--]);
        printf("\n");
    }
    return 0;
}

FWT:

615 original articles published, praised 23, visited 30000+
His message board follow

Posted on Mon, 13 Jan 2020 07:29:53 -0800 by samafua