알고리즘/백준(BOJ) 문제풀이

BOJ 1067 이동, BOJ 22289 큰 수 곱셈 (3) (FFT, 고속 푸리에 변환)

restudy 2022. 6. 27. 16:57
반응형

이 포스트에서는 지난 포스트에서 다룬 고속 푸리에 변환(FFT, Fast Fourier Transform) 알고리즘을 직접 백준 BOJ 문제들에 응용하는 풀이들에 대해 다루어볼 것입니다.

사실 저는 아직 FFT의 응용에 익숙하지 않기 때문에 쉬운 문제들부터 풀어보도록 하겠습니다.

 

 

백준 BOJ 1067번 : 이동

문제 난이도 : Platinum I

알고리즘 분류 : 고속 푸리에 변환 (FFT)

 

 

길이 N인 수열 X와 Y가 있을 때, 수열들을 적당히 순환시켜 순서를 바꾼 뒤 문제에서 주어진 식 S를 계산했을 때 그 최댓값을 구하는 문제입니다.

 

식 S의 꼴을 보면 원소 순서대로 곱하는, 즉 내적 형태의 식을 가지고 있습니다.

따라서 FFT에서 수행되는 이산 합성 곱의 형태를 적절히 변형하면 문제에서 요구하는 식의 형태를 만들 수 있을 것이라는 사실을 유추할 수 있습니다.

 

 

 

항의 꼴을 최대한 비슷하게 만들어주기 위해 하나의 벡터를 뒤집어서 이산 합성곱을 수행해주면 위와 같이 여러의 항 중 하나의 항에 a1b1 + a2b2 + a3b3의 꼴이 얻어짐을 생각해볼 수 있습니다.

 

 

 

이 문제에서는 회전시킨 벡터와의 이산 합성곱 역시 구해야하므로, 하나의 벡터를 두 번 반복한 벡터를 만들어 이산 합성곱을 계산해보면, 위와 같이 우리가 원하는 항들이 모두 등장함을 알 수 있습니다.

나머지 항들은 빨간색 항에 포함되는 항들인데, 문제 조건에서 원소들의 값은 모두 0 이상의 정수라고 하였으므로 음수는 고려해줄 필요가 없으므로, 모두 같거나 작음을 알 수 있습니다. (따라서 무시해줄 수 있습니다.)

 

 

 

따라서 두 개의 벡터 중 하나는 2개를 연결한 꼴로 만들어주어 이산 합성곱을 FFT로 O(N log N) 시간에 수행해주고, 나온 벡터의 항들 중 최댓값을 구하면 답이 됩니다.

 

 

#include <bits/stdc++.h>
#define int long long
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &v, bool inv) {
    int S = v.size();

    for(int i=1, j=0; i<S; i++) {
        int bit = S/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(v[i], v[j]);
    }

    for(int k=1; k<S; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx dir(cos(angle), sin(angle));

        for(int i=0; i<S; i+=k*2) {
            cpx unit(1, 0);

            for(int j=0; j<k; j++) {
                cpx a = v[i+j];
                cpx b = v[i+j+k] * unit;

                v[i+j] = a + b;
                v[i+j+k] = a - b;

                unit *= dir;
            }
        }
    }

    if(inv)
        for(int i=0; i<S; i++) v[i] /= S;
}

vector<cpx> mul(vector<cpx> &v, vector<cpx> &u) {
    int S = 1;
    while(S < max(v.size(), u.size())) S *= 2;

    v.resize(S); FFT(v, false);
    u.resize(S); FFT(u, false);

    vector<cpx> w(S);
    for(int i=0; i<S; i++) w[i] = v[i] * u[i];
    FFT(w, true);

    return w;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;
    vector<cpx> v(N*2), u(N);

    for(int i=0; i<N; i++) {
        int x; cin >> x;
        v[i] = v[i+N] = cpx(x, 0);
    }
    for(int i=0; i<N; i++) {
        int x; cin >> x;
        u[N-1-i] = cpx(x, 0);
    }

    vector<cpx> w = mul(v, u);

    int ans = 0;
    for(int i=0; i<w.size(); i++) ans = max(ans, (int)round(w[i].real()));

    cout << ans << "\n";
}

 

구현은 위와 같이 수행해줄 수 있습니다.

main 함수를 제외한 나머지 코드 부분들은 모두 FFT의 기본 코드를 그대로 작성한 것이며, 단순히 v, u 벡터에 complex의 실수 값만 적절히 넣어주면 됩니다.

 

 

 

 

백준 BOJ 22289번 : 큰 수 곱셈 (3)

문제 난이도 : Platinum I

알고리즘 분류 : 고속 푸리에 변환 (FFT)

 

이 문제를 풀이하면 백준 BOJ 15576번 : 큰 수 곱셈 (2) 역시 풀이할 수 있으므로 (시간 제한이 더 넉넉함), 조건이 더 강화된 22289번 문제를 풀이하도록 하겠습니다.

 

 

 

두 수의 곱을 출력하는 간단한 문제이지만, 두 수의 자릿수가 최대 100만 자리인데 비해 시간 제한이 1초로 엄격합니다.

따라서 단순 곱을 수행하는 O(N^2) 알고리즘보다 빠른 알고리즘을 이용해야 합니다.

 

생각해보면 이산 합성곱 자체가 두 수의 곱과 같은 순서로 수행되기 때문에, FFT를 이용하면 O(N log N) 시간에 수행되는 곱셈을 구현할 수 있습니다. (이 때 N은 둘 중 큰 수의 자릿수)

예를 들어 123 x 456을 계산한다고 할 때, 맨 앞 자릿수는 1x4가 되고, 그 다음 자릿수는 1x5 + 2x4 (일단은), 그 뒤는 차례로 1x6 + 2x5 + 3x4, 2x6 + 3x5, 3x6이 됩니다.

이제 자릿수에 맞게 올림만 처리해주면 56088이라는 값이 얻어지는 것입니다.

 

즉, 우리는 각 자릿수의 수를 벡터의 원소로 나누어 저장하고, 두 벡터 간 FFT를 수행한 뒤 올림 처리만 해주면 답을 얻을 수 있는 것입니다.

 

 

#include <bits/stdc++.h>
#define int long long
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &v, bool inv) {
    int S = v.size();

    for(int i=1, j=0; i<S; i++) {
        int bit = S/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(v[i], v[j]);
    }

    for(int k=1; k<S; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx dir(cos(angle), sin(angle));

        for(int i=0; i<S; i+=k*2) {
            cpx unit(1, 0);

            for(int j=0; j<k; j++) {
                cpx a = v[i+j];
                cpx b = v[i+j+k] * unit;

                v[i+j] = a + b;
                v[i+j+k] = a - b;

                unit *= dir;
            }
        }
    }

    if(inv)
        for(int i=0; i<S; i++) v[i] /= S;
}

vector<cpx> mul(vector<cpx> &v, vector<cpx> &u) {
    int S = 2;
    while(S < v.size() + u.size()) S *= 2;

    v.resize(S); FFT(v, false);
    u.resize(S); FFT(u, false);

    vector<cpx> w(S);
    for(int i=0; i<S; i++) w[i] = v[i] * u[i];
    FFT(w, true);

    return w;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    string a, b; cin >> a >> b;

    vector<cpx> v, u;
    for(int i=0; i<a.length(); i++) v.push_back(cpx(a[i] - '0', 0));
    for(int i=0; i<b.length(); i++) u.push_back(cpx(b[i] - '0', 0));

    reverse(v.begin(), v.end());
    reverse(u.begin(), u.end());

    vector<cpx> w = mul(v, u);

    vector<int> ans(w.size());
    for(int i=0; i<ans.size(); i++) ans[i] = round(w[i].real());

    for(int i=0; i<ans.size(); i++) {
        if(ans[i] < 10) continue;

        if(i < ans.size()-1) ans[i+1] += ans[i]/10;
        else ans.push_back(ans[i]/10);

        ans[i] %= 10;
    }

    reverse(ans.begin(), ans.end());

    int i=0; while(ans[i] == 0) i++;
    if(i >= ans.size()) cout << 0 << "\n";

    while(i < ans.size()) {
        cout << ans[i];
        i++;
    }
    cout << "\n";
}

 

올림 처리를 편하게 하고자 v, u 벡터에 원소를 저장한 뒤 이를 뒤집어서 FFT를 수행해줍니다.

이제 올림 처리를 할 것인데, (현재 설명하는 부분에서 수는 ans 벡터에 저장되어 있음) ans[i] 값이 10 이상인 경우가 올림 처리를 해야하는 경우에 해당됩니다.

이 경우 다음 자리에 (지금 배열이 뒤집혀있으므로) ans[i]/10, 즉 올림할 수를 올려주고, 원래 자리인 ans[i] = ans[i] % 10으로 처리해주면 됩니다.

 

그리고 예외 처리를 하나 해주어야 하는데, 바로 0 + 0이 수행된 경우 답이 0인데, leading zero를 무시해버리면 답이 출력되지 않습니다.

따라서 0인 경우만 따로 처리해주면 정답 처리를 받을 수 있습니다.

 

 

 

 

 

반응형