알고리즘/백준(BOJ) 문제풀이

백준 BOJ 13575번 : 보석 가게 풀이 (FFT, 고속 푸리에 변환)

restudy 2022. 6. 29. 18:06
반응형

백준 BOJ 13575번 : 보석 가게

문제 난이도 : Platinum I

알고리즘 분류 : FFT (고속 푸리에 변환), 분할 정복을 이용한 거듭 제곱

 

 

 

N개의 값 중에서 K개를 중복을 허용하여 선택할 때, 얻어질 수 있는 합의 가짓수를 구하는 문제입니다.

예를 들어 N = 3, 주어진 배열 A = { 1, 2, 3 }, K = 2라고 할 때 1 + 1 = 2, 1 + 2 = 3, 1 + 3 = 2 + 2 = 4, 2 + 3 = 5로 총 4가지 합을 얻을 수 있어서 답은 4가 됩니다. (문제 조건에 따라 다른 조합으로 같은 합을 얻더라도 하나의 합으로 카운트 해주어야 합니다.)

 

FFT를 이용하면 두 집합의 합으로 얻어질 수 있는 값들의 종류와 그 가짓수를 구할 수 있으므로, FFT를 활용하여 풀이하는 방법을 생각해볼 수 있습니다.

 

 

 

 

수 K개의 합을 고려해야 하므로 범위를 생각해보면, 원소 a_i는 최대 1000이므로 이산 합성곱이 연산되지 않은 벡터는 최대 1000개 정도의 항을 가져야하고, K가 최대 1000이므로 이 벡터를 1000번 제곱했다고 생각하면 벡터의 크기는 대략 1,000,000이 됩니다.

이는 충분히 메모리를 초과시키지 않는 범위이므로 구현이 가능함을 알 수 있습니다.

 

이제 시간 복잡도를 생각해보면, 이산 합성곱이 반복되면서 벡터의 크기 역시 길어지는데 이를 굳이 고려하지 않고 넉넉하게 처음부터 (a_i의 최대 크기) × K = 1000K라고 생각해봅시다. (a_i의 최댓값을 K번 반복해서 선택하면 그 최댓값이 K × a_i가 되는데 이를 벡터의 주소에 count 해주어야 하므로)

FFT의 시간 복잡도는 O(N log N)이므로 현재까지는 대략 O(1000N × log(1000N))입니다.

 

그런데, 우리는 이러한 연산을 K번 반복해주어야 합니다.

그러면 시간 복잡도는 O(1000KN log(1000N))이 되겠네요. K와 N 모두 최대 1000이므로 10억 × log(100만) = 60억이면 시간 초과가 나기 충분합니다.

따라서 여기서 시간을 더 줄이기 위해, K번의 거듭제곱을 분할 정복을 이용한 거듭 제곱을 활용하여 log 시간에 수행해줍시다.

 

그러면 최종 시간 복잡도는 O(1000N log(1000N) log(K))가 되어 시간 안에 통과할 수 있게 됩니다.

 

 

#include <bits/stdc++.h>
#define int long long
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &v, bool inv) {
    int S = v.size();

    for(int i=1, j=0; i<S; i++) {
        int bit = S/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(v[i], v[j]);
    }

    for(int k=1; k<S; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx w(cos(angle), sin(angle));

        for(int i=0; i<S; i+=k*2) {
            cpx z(1, 0);

            for(int j=0; j<k; j++) {
                cpx even = v[i+j];
                cpx odd = v[i+j+k];

                v[i+j] = even + z*odd;
                v[i+j+k] = even - z*odd;

                z *= w;
            }
        }
    }

    if(inv)
        for(int i=0; i<S; i++) v[i] /= S;
}

vector<int> mul(vector<int> &v, vector<int> &u) {
    vector<cpx> vc(v.begin(), v.end());
    vector<cpx> uc(u.begin(), u.end());

    int S = 2;
    while(S < v.size() + u.size()) S *= 2;

    vc.resize(S); FFT(vc, false);
    uc.resize(S); FFT(uc, false);

    for(int i=0; i<S; i++) vc[i] *= uc[i];
    FFT(vc, true);

    vector<int> w(S);
    for(int i=0; i<S; i++)
        if(round(vc[i].real()) != 0) w[i] = 1;

    return w;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N, M; cin >> N >> M;

    vector<int> v(1024);

    for(int i=0; i<N; i++) {
        int x; cin >> x;
        v[x] = 1;
    }

    vector<int> u(1, 1);

    while(M > 0) {
        if(M % 2 == 1) u = mul(u, v);

        v = mul(v, v);
        M /= 2;
    }

    for(int i=0; i<u.size(); i++)
        if(u[i] != 0) cout << i << " ";
    cout << "\n";
}

 

구현 자체는 위에서 언급한대로 하면 거의 해결이 됩니다.

 

다만 여전히 확실히 이해가 가지 않는 부분은 벡터의 크기를 2^k 꼴로 잡지 않으면 메모리 초과 또는 WA를 받는다는 것인데, 추정되는 원인으로는 'FFT 코드에서 반드시 벡터를 2^k 꼴로 resize를 하기 때문에, 불필요한 메모리가 계속해서 생기고 이것이 기하급수적으로 늘어나서' 정도가 있을 것 같습니다.

 

나머지 부분은 FFT 코드와, main 함수에서 log 시간에 수행이 가능한 거듭제곱의 구현 부분뿐입니다.

분할정복을 이용한 거듭제곱의 구현은 다른 포스트에서 여러 번 다루었으니 블로그 내 검색을 통해 참고하시면 좋을 것 같습니다.

특별히 신경써주면 좋은 점은, FFT를 반복함에 따라 원소의 값이 기하급수적으로 증가하기 때문에 overflow를 방지하고자 매 FFT 연산마다 원소의 값이 0보다 큰 것은 안전하게 모두 1로 바꿔주는 것입니다.

 

 

 

 

 

반응형