알고리즘/알고리즘 공부 내용 정리

FFT(고속 푸리에 변환, Fast Fourier Transform) 코드 (쿨리-튜키 + 빠른 코드)

restudy 2022. 4. 4. 22:53
반응형
#include <bits/stdc++.h>
using namespace std;

const double PI = acos(-1);
typedef long long ll;
typedef complex<double> cpx;

void FFT(vector<cpx> &f, cpx x) {
    int n = f.size();
    if(n == 1) return;

    vector<cpx> even(n/2), odd(n/2);
    for(int i=0; i<n; i++) {
        if(i % 2 == 0) even[i/2] = f[i];
        else odd[i/2] = f[i];
    }

    FFT(even, x*x);
    FFT(odd, x*x);

    cpx unit(1, 0);
    for(int i=0; i<n/2; i++) {
        f[i] = even[i] + unit*odd[i];
        f[i + n/2] = even[i] - unit*odd[i];

        unit *= x;
    }
}

vector<ll> multiply(vector<ll> &a, vector<ll> &b) {
    vector<ll> c(a.size() + b.size() - 1);

    int n = 1;
    while(n <= a.size() || n <= b.size()) n *= 2;
    n *= 2;

    a.resize(n);
    b.resize(n);

    vector<cpx> a_(n), b_(n);
    for(int i=0; i<n; i++) {
        a_[i] = cpx(a[i], 0);
        b_[i] = cpx(b[i], 0);
    }
    cpx unit(cos(2*PI/n), sin(2*PI/n));

    FFT(a_, unit);
    FFT(b_, unit);

    vector<cpx> c_(n);
    for(int i=0; i<n; i++) c_[i] = a_[i] * b_[i];

    FFT(c_, cpx(1, 0)/unit);
    for(int i=0; i<n; i++) c_[i] /= cpx(n, 0);

    for(int i=0; i<c.size(); i++) c[i] = round(c_[i].real());

    return c;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    vector<ll> a = {1, 2};
    vector<ll> b = {1, 3};

    vector<ll> c = multiply(a, b);
    for(int i=0; i<c.size(); i++) cout << c[i] << " ";
}

 

두 그룹 간의 이산 합성곱을 O(N log N) 시간에 빠르게 구할 수 있는 FFT(고속 푸리에 변환, Fast Fourier Transform) 알고리즘을 깔끔하게 정리해보았습니다.

 

물론 다른 사이트에도 많은 구현들이 존재하지만 바로 복사-붙여넣기 해서 사용할 수 있는 정형화 된 코드로 정리해두고 싶다는 생각이 들어서 작성하게 되었습니다.

참고로 이 포스트에서 구현되어 있는 코드는 정확하게는 쿨리-튜키(Cooley-Tukey) 알고리즘으로 FFT의 다양한 알고리즘들 중 대표적인 알고리즘입니다.

 

알고리즘에 대한 설명은 다른 포스트를 따로 작성하여 설명하는 기회를 가지도록 하겠습니다.

일단은 FFT에 대한 여러 응용 문제들을 풀어보고 나서 추가로 정리하려고 합니다.

 

 

코드를 응용하려면 main 함수의 a와 b 벡터의 값들을 적절히 고쳐서 사용하시면 됩니다.

예를 들어 위의 코드에는 a의 원소는 1, 2이고 b의 원소는 1, 3이며 실행해보면 c에는 1, 5 ,6이 들어가 출력이 되는데 이는 1 = 1 × 1, 5 = 1 × 3 + 2 × 1, 6 = 2 × 3을 의미합니다.

 


(내용 추가) 더 빠른 코드를 첨부합니다.

문제들을 풀어보니 위의 쿨리-튤키 알고리즘으로는 시간초과가 발생하여 해결이 불가능한 문제가 있었습니다.

따라서 아래와 같이 약간 더 빠른 코드를 정형화하여 정리해보았습니다.

 

 

#include <bits/stdc++.h>
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &f, bool inv) {
    int N = f.size();

    for(int i=1, j=0; i<N; i++) {
        int bit = N/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(f[i], f[j]);
    }

    for(int k=1; k<N; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx dir(cos(angle), sin(angle));

        for(int i=0; i<N; i+=k*2) {
            cpx unit(1, 0);

            for(int j=0; j<k; j++) {
                cpx u = f[i+j];
                cpx v = f[i+j+k] * unit;

                f[i+j] = u + v;
                f[i+j+k] = u - v;

                unit *= dir;
            }
        }
    }

    if(inv)
        for(int i=0; i<N; i++) f[i] /= N;
}

vector<cpx> multiply(vector<cpx> &a, vector<cpx> &b) {
    int N = 1;
    while(N < a.size() + b.size()) N *= 2;

    a.resize(N); FFT(a, false);
    b.resize(N); FFT(b, false);

    vector<cpx> c(N);
    for(int i=0; i<N; i++) c[i] = a[i] * b[i];
    FFT(c, true);

    return c;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    vector<cpx> a = {cpx(1, 0), cpx(2, 0), cpx(3, 0)};
    vector<cpx> b = {cpx(1, 0), cpx(2, 0), cpx(3, 0)};

    int Size = a.size() + b.size() - 1;

    vector<cpx> c = multiply(a, b);
    for(int i=0; i<Size; i++) cout << round(c[i].real()) << " ";
}

 

같은 문제를 위의 두 알고리즘으로 풀이해본 결과 바로 위의 코드가 쿨리-튤키 알고리즘에 비해 속도는 대략 2.5배 정도 빠릅니다.

 

 

 

반응형