알고리즘/백준(BOJ) 문제풀이

백준 BOJ 13055번 : K-Inversions 풀이 (FFT, 고속 푸리에 변환)

restudy 2022. 6. 29. 23:13
반응형

백준 BOJ 13055번 : K-Inversions

문제 난이도 : Diamond V

알고리즘 분류 : FFT (고속 푸리에 변환)

 

 

B가 A보다 먼저 나오는 쌍을 inversion이라고 할 때, B와 A 사이의 거리가 K일 때 그 쌍을 K-Inversion이라고 정의합시다.

A와 B로만 이루어진 문자열에 대해 0 ~ N-1사이의 모든 K-inversion의 수를 구하는 문제입니다.

 

예를 들어 문제에서 예시로 든 BABA를 가지고 설명하자면, 앞의 BA와 뒤의 BA는 모두 B → A 순서대로 나타나며 B와 A 사이의 거리가 1이므로 1-inversion이고, 이는 총 2회 나옵니다.

그리고 맨 앞의 B와 맨 뒤의 A 역시 inversion에 해당되고, 두 문자 사이의 거리는 3이므로 3-inversion은 1개 나타납니다.

따라서 1, 2, 3-inversion의 수를 순서대로 출력하면 되므로 답은 2 0 1이 됩니다.

 

 

 

우선 드는 생각은 1 ~ N-1 사이의 K-inversion의 수를 모두 구해야하므로, 결국은 모두 일대일로 대조하면서 카운트해주는 방식이 유일해보입니다.

그런데 이렇게 할 경우 문자열의 길이가 최대 100만자리나 되므로 시간 초과를 받을 수밖에 없습니다.

따라서 당연히도 효율적인 알고리즘을 생각해내야 합니다.

 

 

 

해볼 수 있는 생각은, FFT의 이산 합성곱의 형태는 일반적인 곱셈식에서 각 자릿수가 가지는 관계와 비슷하므로, 곱셈식의 형태를 활용해보는 것입니다.

예를 들어 abc x def를 계산하면 위와 같이 아랫줄에 위치한 수의 각 자릿수가 윗줄에 위치한 수의 각 자릿수에 대응되면서 계산이 된다는 원리를 가지고 있습니다.

 

곱해지는 수와 줄의 관계를 잘 생각해보면, 곱셈식에서 각 열에 대응되는 수들의 매치는 아랫줄의 문자 순서를 뒤집어서 한 칸씩 이동했을 때 위아래로 겹치는 문자들의 쌍과 같습니다.

이러한 원리를 응용하면 inversion들의 counting이 충분히 가능할 것이라는 아이디어를 구상할 수 있습니다.

 

 

 

조금 구체적인 예시로 문자열 S = BABAA를 이용해보겠습니다.

BABAA는 그대로 두고, 아랫줄에 BABAA를 뒤집은 문자열 AABAB를 이용하여 한 칸씩 이동하면서 위와 같이 곱셈식의 원리대로 매치를 시켜보면, 각 위치에 대응되는 A와 B의 관계는 K-inversion에서의 K와 같은 증감을 나타내고 있습니다.

(이 때 아랫줄을 뒤집었으므로 윗줄의 A와 아랫줄의 B를 매칭시키고, 나머지는 체크하지 않도록 합시다. 예를 들어 윗줄의 B와 아랫줄의 A 역시 매칭시킨다면 카운트가 2번씩 될 것이기 때문에 하나만 체크하도록 합시다.)

 

 

 

우리는 윗줄의 A와 아랫줄의 B인 쌍만 count하고 싶으므로, 윗줄에서는 A에 대응되는 위치에만 1로, 아랫줄에는 B에 대응되는 위치에만 1로 수를 치환해주도록 합시다.

이제 곱셈식에서 써본대로 똑같이 두 수를 곱해보면, 윗줄에 A, 아랫줄에 B인 쌍에서만 곱했을 때 값이 1이 나오고, 이들을 누적시켜보면 해당 거리를 가지고 있는 쌍들끼리 묶이기 때문에, 각 자릿수에서 K-inversion들의 값이 순서대로 얻어지게 됩니다.

각각의 값들은 10 이상이 될 수도 있지만, 우리는 곱셈식으로 푸는 것이 아닌 벡터의 이산 합성곱으로 계산할 것이기 때문에 문제없습니다.

 

 

 

결론적으로 우리는 문자열의 A = 1, B = 0을 적용한 벡터에, 뒤집은 문자열의 A = 0, B = 1을 적용한 벡터를 이산 합성곱을 해준 뒤, 얻어진 벡터에서 뒤의 N-1개의 성분을 출력해주기만 하면 됩니다.

 

 

#include <bits/stdc++.h>
#define int long long
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &v, bool inv) {
    int S = v.size();

    for(int i=1, j=0; i<S; i++) {
        int bit = S/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(v[i], v[j]);
    }

    for(int k=1; k<S; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx w(cos(angle), sin(angle));

        for(int i=0; i<S; i+=k*2) {
            cpx z(1, 0);

            for(int j=0; j<k; j++) {
                cpx even = v[i+j];
                cpx odd = v[i+j+k];

                v[i+j] = even + z*odd;
                v[i+j+k] = even - z*odd;

                z *= w;
            }
        }
    }

    if(inv)
        for(int i=0; i<S; i++) v[i] /= S;
}

vector<int> mul(vector<int> &v, vector<int> &u) {
    vector<cpx> vc(v.begin(), v.end());
    vector<cpx> uc(u.begin(), u.end());

    int S = 2;
    while(S < v.size() + u.size()) S *= 2;

    vc.resize(S); FFT(vc, false);
    uc.resize(S); FFT(uc, false);

    for(int i=0; i<S; i++) vc[i] *= uc[i];
    FFT(vc, true);

    vector<int> w(S);
    for(int i=0; i<S; i++) w[i] = round(vc[i].real());

    return w;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    string str; cin >> str;

    int N = str.length();

    vector<int> v(N), u(N);

    for(int i=0; i<N; i++)
        if(str[i] == 'A') v[i] = 1;

    for(int i=0; i<N; i++)
        if(str[i] == 'B') u[N-1-i] = 1;

    vector<int> w = mul(v, u);

    for(int i=N; i<N*2-1; i++) cout << w[i] << "\n";
}

 

이 포스트에서는 FFT 응용 문제를 다루고 있기 때문에 FFT 코드에 대한 설명은 생략하기로 하고, main 함수 부분만 언급해보자면 문자열을 1 또는 0으로 치환하고, 하나는 뒤집어서 그대로 이산 합성곱을 수행해준 뒤 일부 자릿수만 출력해주면 됩니다.

 

아이디어가 다른 플레티넘1 난이도의 FFT 문제보다 어려워서 그렇지, FFT 코드를 제외한 구현 자체는 별로 할 것이 없는 문제입니다.

 

 


+ 저는 처음에 굳이 int형 벡터로 넘겨야 하는 이유를 잘 모르겠어서 그냥 바로 complex형으로 값을 대입하는 방식으로 풀었는데, 그렇게 하니 풀이 시간이 거의 2배가 나오더군요.

속도 문제 때문에라도 int형 벡터로 한 번 넘겨서 사용하도록 해야겠습니다.

 

그리고 아래와 같이 코드를 작성하면, 마지막에 abs 처리를 하지 않으면 -0이 출력되어 WA를 받을 수 있습니다.

complex 자료형을 사용할 때는 -0이 출력되는 경우를 주의합시다.

 

↓ 풀이 코드는 여기 있습니다. (여전히 정답 처리를 받은 코드이긴 합니다.)

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

const double PI = acos(-1);
typedef complex<double> cpx;

void FFT(vector<cpx> &v, bool inv) {
    int S = v.size();

    for(int i=1, j=0; i<S; i++) {
        int bit = S/2;

        while(j >= bit) {
            j -= bit;
            bit /= 2;
        }
        j += bit;

        if(i < j) swap(v[i], v[j]);
    }

    for(int k=1; k<S; k*=2) {
        double angle = (inv ? PI/k : -PI/k);
        cpx w(cos(angle), sin(angle));

        for(int i=0; i<S; i+=k*2) {
            cpx z(1, 0);

            for(int j=0; j<k; j++) {
                cpx even = v[i+j];
                cpx odd = v[i+j+k];

                v[i+j] = even + z*odd;
                v[i+j+k] = even - z*odd;

                z *= w;
            }
        }
    }

    if(inv)
        for(int i=0; i<S; i++) v[i] /= S;
}

void mul(vector<cpx> &v, vector<cpx> &u) {
    int S = 2;
    while(S < v.size() + u.size()) S *= 2;

    v.resize(S); FFT(v, false);
    u.resize(S); FFT(u, false);

    for(int i=0; i<S; i++) v[i] *= u[i];
    FFT(v, true);
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    string str; cin >> str;

    int N = str.length();

    vector<cpx> v(N), u(N);

    for(int i=0; i<N; i++)
        if(str[i] == 'A') v[i] = cpx(1, 0);

    for(int i=0; i<N; i++)
        if(str[N-i-1] == 'B') u[i] = cpx(1, 0);

    mul(v, u);

    for(int i=N; i<N*2-1; i++) cout << abs(round(v[i].real())) << "\n";
}

 

 

 

반응형