고속 푸리에 변환(FFT, Fast Fourier Transform)은 이산 합성곱을 O(N log N) 시간에 계산할 수 있는 알고리즘입니다.
쉽게 말해 두 N차 (또는 그 이하) 다항식의 곱의 계수들을 O(N log N) 시간에 계산할 수 있는 알고리즘이라고 생각하면 됩니다.
개요
N차 실수 벡터에 이산 푸리에 변환(DFT, Discrete Fourier Transform)이라는 조작 과정을 거쳐 특정한 N차 복소수 벡터의 형태로 만들면, 단순 내적 곱으로도 두 식을 곱한 후 DFT한 값과 같은 값이 얻어지게 됩니다.
따라서 O(N) 시간에 두 식을 곱하고, 다시 역 이산 푸리에 변환(IDFT, Inverse Discrete Fourier Transform)의 조작 과정을 거쳐 실수 N차 벡터로 재변환해주면 O(N log N) 시간 복잡도로 두 벡터의 이산 합성곱을 구할 수 있게 되는 것입니다.
→ 요약하면 원래는 O(N^2) 시간에나 가능한 이산 곱을, O(N log N) (DFT) + O(N) (inner product) + O(N log N) (IDFT)으로 도합 O(N log N)에 가능하게 해주는 알고리즘입니다.
(** 이 포스트에서만 벡터 A에 DFT를 수행한 벡터를 편의상 A'으로 정의하겠습니다.)
(** 그림에서는 × 연산으로 표시되어 있는데 이산 합성곱 연산은 통상적으로 * 기호로 표현하는 것이 맞습니다.)
예를 들며 다시 설명하자면, 위의 그림과 같이 A 벡터와 B 벡터가 있다고 가정합시다.
엄밀히는 둘 중 차수가 큰 것이 N차라고 해야겠지만, 편의상 위처럼 둘 다 N차로 나타내기로 하고 DFT 과정을 거치면 A'과 B'이 얻어집니다.
DFT에서는 식을 특수하게 정의하여 잡고 그에 따라 변환시켰기 때문에, A'과 B'의 각 성분을 내적하여 얻은 새로운 백터 C'가 A와 B의 이산곱을 수행한 벡터에 DFT를 수행하여 얻은 (A * B)'와 같게 됩니다.
이제 이 C'을 원래 벡터로 거치는 과정인, DFT의 역변환 과정 즉 IDFT를 수행하여 C를 얻을 수 있게 되는 것입니다.
그럼 이제 자세한 과정을 살펴봅시다.
(위의 개요보다는 자세하게는 적겠지만, 모든 과정을 완벽하게 이해하고 적는 것이 아니기 때문에 설명의 구체성이 부족할 수 있으며 부족한 부분이 있으면 지적 바랍니다.)
DFT
먼저 DFT 과정을 살펴봅시다.
위의 식 ①과 식 ②를 정의합시다. (이 정의는 FFT(DFT, IDFT) 과정을 수행 가능하게 만들기 위한 정의이며, 이 식을 어떻게 발견했는지에 대해서는 잘 모르겠습니다. 하지만 신기하게도 위의 정의에 따라 FFT의 모든 과정이 맞아떨어집니다.)
위의 식이 시그마 형태로 나와있기 때문에, N = 4인 경우에 대해서만 예시를 간단히 적어보면, 위에서 정의한 식 ②에 따라 위와 같이 A_4'을 적어볼 수 있습니다.
위의 식은 A_4에 4 x 4 크기의 행렬, 즉 N x N 행렬이 곱해졌으므로 O(N^2)의 시간이 소요될 것으로 생각되지만, 홀수차항과 짝수차항으로 분리하여 식을 위와 같이 나눠서 나타내어보면, f_e와 f_o를 찾을 수 있고 각각의 식에 대해 재귀적으로 이 과정을 반복하면, Divide and Conquer 기법의 원리에 따라 O(N log N)시간에 DFT를 수행할 수 있게 됩니다.
(이 역시 위에서의 식 ①과 식 ②의 정의에 따라 가능하게 된 것입니다.)
곱 연산
이제 DFT가 완료된 두 식을 곱합니다.
위에서도 언급했듯, DFT가 완료된 식은 각 성분끼리만 곱해도 두 벡터의 이산곱을 DFT한 것과 같습니다.
따라서 그냥 O(N) 시간에 N개의 항을 끼리끼리 곱해주기만 하면 됩니다.
IDFT
마지막으로 IDFT를 수행해주어야 합니다.
우리는 A'과 B'의 inner product를 수행하여 C' 벡터 하나만을 가지고 있습니다.
역연산의 경우, 위에서 곱해준 행렬 [w_ij]의 역행렬을 곱해주면 되며, 이 역행렬의 경우 이미 계산과 일반화가 모두 잘 되어 있습니다.
우리는 그 결과값을 곱해주기면 하면 되는데, 그것은 식 ③과 같습니다.
정리된 식을 곱해주기면 하면 IDFT가 끝나며 최종적으로 A와 B의 이산 곱 C가 얻어지게 됩니다.
알고리즘 구현 코드 (속도 느린 코드)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const double PI = acos(-1);
typedef complex<double> cpx;
void FFT(vector<cpx> &v, cpx w) {
int n = v.size();
if(n == 1) return;
vector<cpx> even(n/2), odd(n/2);
for(int i=0; i<n; i++) {
if(i % 2 == 0) even[i/2] = v[i];
else odd[i/2] = v[i];
}
FFT(even, w*w);
FFT(odd, w*w);
cpx z(1, 0);
for(int i=0; i<n/2; i++) {
v[i] = even[i] + z*odd[i];
v[i + n/2] = even[i] - z*odd[i];
z *= w;
}
}
vector<int> multiply(vector<int> v, vector<int> u) {
vector<int> w(v.size() + u.size() - 1);
int n = 1;
while(n <= v.size() || n <= u.size()) n *= 2;
n *= 2;
v.resize(n);
u.resize(n);
vector<cpx> v_(n), u_(n);
for(int i=0; i<n; i++) {
v_[i] = cpx(v[i], 0);
u_[i] = cpx(u[i], 0);
}
cpx unit(cos(2*PI/n), sin(2*PI/n));
FFT(v_, unit);
FFT(u_, unit);
vector<cpx> w_(n);
for(int i=0; i<n; i++) w_[i] = v_[i] * u_[i];
FFT(w_, cpx(1, 0)/unit);
for(int i=0; i<n; i++) w_[i] /= cpx(n, 0);
for(int i=0; i<w.size(); i++) w[i] = round(w_[i].real());
return w;
}
main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL), cout.tie(NULL);
vector<int> v = {1, 2};
vector<int> u = {1, 3};
vector<int> w = multiply(v, u);
for(int i=0; i<w.size(); i++) cout << w[i] << " ";
}
(** 주의 : 아래에 더 빠른 코드가 있으니 필요하신 분은 아래쪽까지 봐주세요.)
위의 FFT 알고리즘에 대한 설명을 직접 C++ 코드로 구현하면 위와 같이 됩니다.
벡터 a에는 {1, 2}가, 벡터 b에는 {1, 3}이 들어있고 이산 곱을 수행해주면 {1x1, 1x3 + 1x2, 2x3} = {1, 5, 6}이 얻어지게 됩니다.
위의 코드를 직접 실행해보아도 "1 5 6"이 출력됨을 확인할 수 있습니다.
문제를 풀이할 때는 위의 코드에서 a와 b의 값을 적절히 입력받아 multiply 함수로 이산 곱을 수행하고 c를 다루어 답을 구하면 됩니다.
더 빠른 FFT 알고리즘 (재귀 → 비재귀)
그런데 위의 FFT 알고리즘보다 약 2.5배 정도 빠른 알고리즘을 구현할 수 있습니다.
위에서 구현한 알고리즘은 재귀 형식을 사용하여 DFT 또는 IDFT가 이루어집니다.
그런데 이 재귀 방식을 비재귀 방식으로 풀고, 메모리 낭비 등의 단점을 개선하여 더 효율적인 알고리즘을 구현할 수 있습니다.
이에 대한 원리는 이 링크에 잘 정리되어 있는 것으로 보이며, 여기에서는 코드만 정리하도록 하겠습니다.
#include <bits/stdc++.h>
#define int long long
using namespace std;
const double PI = acos(-1);
typedef complex<double> cpx;
void FFT(vector<cpx> &v, bool inv) {
int S = v.size();
for(int i=1, j=0; i<S; i++) {
int bit = S/2;
while(j >= bit) {
j -= bit;
bit /= 2;
}
j += bit;
if(i < j) swap(v[i], v[j]);
}
for(int k=1; k<S; k*=2) {
double angle = (inv ? PI/k : -PI/k);
cpx w(cos(angle), sin(angle));
for(int i=0; i<S; i+=k*2) {
cpx z(1, 0);
for(int j=0; j<k; j++) {
cpx even = v[i+j];
cpx odd = v[i+j+k];
v[i+j] = even + z*odd;
v[i+j+k] = even - z*odd;
z *= w;
}
}
}
if(inv)
for(int i=0; i<S; i++) v[i] /= S;
}
vector<int> mul(vector<int> &v, vector<int> &u) {
vector<cpx> vc(v.begin(), v.end());
vector<cpx> uc(u.begin(), u.end());
int S = 2;
while(S < v.size() + u.size()) S *= 2;
vc.resize(S); FFT(vc, false);
uc.resize(S); FFT(uc, false);
for(int i=0; i<S; i++) vc[i] *= uc[i];
FFT(vc, true);
vector<int> w(S);
for(int i=0; i<S; i++) w[i] = round(vc[i].real());
return w;
}
main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL), cout.tie(NULL);
int N, M; cin >> N >> M;
vector<int> v(N), u(M);
for(int i=0; i<N; i++) cin >> v[i];
for(int i=0; i<M; i++) cin >> u[i];
vector<int> w = mul(v, u);
for(int i=0; i<v.size()+u.size()-1; i++) cout << w[i] << " ";
cout << "\n";
}
(수정) 입출력부와 벡터 크기와 관련하여 약간 수정을 거쳤습니다. complex 벡터에 바로 값을 넣어서 계산 후 반환 받으면, round를 적용했을 때 가끔 0이 아닌 -0이 출력되는 경우가 있습니다. 그래서 abs를 사용해주어야 하는데, 그보다는 위와 같이 사용하기 편하게 int형 벡터에 값을 넣어서 사용하는 것을 추천드립니다. (이 때 -0 문제 역시 벡터에서 벡터로 값을 옮기면서 0으로 바뀌어 들어가게 됩니다.)
벡터 v의 크기 N과 벡터 u의 크기 M을 순서대로 입력하고, 각 벡터의 원소들을 입력하면 이산 합성곱이 수행된 벡터의 원소들이 공백으로 구분되어 출력됩니다.
예제 풀이
이렇게 마무리하기에는 FFT의 활용에 대한 이해가 부족할 수 있으니, 예제를 하나 풀이해봅시다.
백준에서 가장 대표적인 FFT 알고리즘 활용 문제입니다.
문제를 번역/요약하여 필요한 부분만 설명하자면, N개의 수와 M개 수가 주어질 때, N개의 수들 중에서 1개 이상 2개 이하의 수를 더하여 만들 수 있는 수의 종류가 M개의 수 중 몇 가지인지를 묻는 문제입니다.
이 문제는 0과 1의 곱의 원리와, FFT를 활용하여 풀 수 있습니다.
예를 들어 N개의 수가 1, 3, 4라고 가정하면 벡터 A와 B를 둘 다 (0, 1, 0, 1, 1)로 둡니다. (맨 앞은 a_0이므로 0에 대응되는 값)
그 다음 A와 B의 이산 곱을 구하면 (0, 1, 0, 1, 1) * (0, 1, 0, 1, 1) = (0, 0, 1, 0, 2, 2, 1, 2, 1)이 되므로 우리는 N에서 두 개의 수를 더하여 만들 수 있는 값은 1 이상의 값을 가지는 칸에 해당하는 2, 4, 5, 6, 7, 8임을 알 수 있습니다.
(그리고 여기서 각 수가 해당하는 값은 두 수를 뽑아 더해 이 수를 만들 수 있는 경우의 수에 해당합니다.)
이러한 원리를 활용하여 문제를 풀어줄 수 있습니다.
그리고 주의해야 할 점은, 굳이 두 수를 더하지 않고 한 수만 가지고도 값을 만들 수 있으므로 A_i > 0 이거나 (A x B)_i > 0이면 i는 count 해주어야 한다는 것입니다. (즉, A_i > 0인지도 검사를 해줘야 함)
#include <bits/stdc++.h>
using namespace std;
const double PI = acos(-1);
typedef complex<double> cpx;
void FFT(vector<cpx> &f, bool inv) {
int N = f.size();
for(int i=1, j=0; i<N; i++) {
int bit = N/2;
while(j >= bit) {
j -= bit;
bit /= 2;
}
j += bit;
if(i < j) swap(f[i], f[j]);
}
for(int k=1; k<N; k*=2) {
double angle = (inv ? PI/k : -PI/k);
cpx dir(cos(angle), sin(angle));
for(int i=0; i<N; i+=k*2) {
cpx unit(1, 0);
for(int j=0; j<k; j++) {
cpx u = f[i+j];
cpx v = f[i+j+k] * unit;
f[i+j] = u + v;
f[i+j+k] = u - v;
unit *= dir;
}
}
}
if(inv)
for(int i=0; i<N; i++) f[i] /= N;
}
vector<cpx> multiply(vector<cpx> &a, vector<cpx> &b) {
int N = 1;
while(N < a.size() + b.size()) N *= 2;
a.resize(N); FFT(a, false);
b.resize(N); FFT(b, false);
vector<cpx> c(N);
for(int i=0; i<N; i++) c[i] = a[i] * b[i];
FFT(c, true);
return c;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL), cout.tie(NULL);
vector<cpx> a(200001), b(200001);
vector<int> dist(200001);
int N; cin >> N;
while(N--) {
int i; cin >> i;
a[i] = b[i] = cpx(1, 0);
dist[i] = 1;
}
vector<cpx> c = multiply(a, b);
int M; cin >> M;
int cnt = 0;
while(M--) {
int i; cin >> i;
if(dist[i] > 0 || round(c[i].real()) > 0) cnt++;
}
cout << cnt;
}
설명대로 코드를 구현하면 위와 같이 되며, 기존 코드에서 크게 변하지 않고 main 함수만 조금 바뀌어 있음을 확인할 수 있습니다.
FFT 실수 오차 줄이기
FFT 알고리즘의 코드만 보아도 알 수 있듯, cpx(cos(angle), sin(angle)) 값을 거듭제곱하여 곱할 수록 당연히 오차가 발생할 가능성이 높아지게 됩니다.
예를 들어 다음과 같은 문제에서 위에서 작성한 FFT 코드를 활용하여 문제를 풀이하면 WA를 받게 됩니다.
그렇다면 어떻게 해야 오차를 최소화 할 수 있을까요?
바로 거듭제곱의 형식으로 값을 구해주는 것이 아니라, angle을 직접 j배 하여 값을 구해주면 됩니다.
따라서 위의 빠른 FFT 코드에서 f_even, f_odd를 구하는 부분을 다음과 같이 수정하면 오차를 최소화시킬 수 있습니다.
for(int k=1; k<S; k*=2) {
double angle = (inv ? PI/k : -PI/k);
for(int i=0; i<S; i+=k*2) {
for(int j=0; j<k; j++) {
cpx even = v[i+j], odd = v[i+j+k];
cpx w = cpx(cos(angle*j), sin(angle*j));
v[i+j] = even + w*odd;
v[i+j+k] = even - w*odd;
}
}
}
단, 이 코드의 문제점은 당연하게도 연산 횟수가 큰 폭으로 증가하므로, 풀이 시간이 거의 2.5배 정도나 걸리게 됩니다.
위에서 작성한 FFT의 느린 코드와 속도가 비슷하고, 따라서 몇몇 문제들은 시간 초과에 걸릴 수 있습니다.
따라서 일단은 기존의 빠른 FFT 코드로 풀이를 작성하여 제출하되, 모든 반례를 처리했다고 생각됨에도 WA를 받는다면 실수 오차를 의심해보고 위의 코드로 대체하여 풀이할 수 있습니다.
그 외 생각
FFT 알고리즘을 코드만 따라서 작성해보고 그대로 활용하는 것은 공부의 본질에 맞지 않는다고 생각하여 "대충 이러한 흐름에 따라서 가능한 것이다~"의 느낌으로 글을 정리해보려고 했는데, 생각보다 부족한 부분이 많은 것 같습니다. 틀린 부분이나 보완해야 할 부분은 언제든 지적해주시면 감사하게 생각하겠습니다.
'알고리즘 > 알고리즘 공부 내용 정리' 카테고리의 다른 글
220710 PS 일기 : 구간 합 나머지 O(N)에 구하기, 실수 오차 잡는 법 (임의 정밀도) 등 (0) | 2022.07.09 |
---|---|
알고리즘 문제(PS) 풀 때 도움되는 정보들 (0) | 2022.07.05 |
[C++ 알고리즘] 느리게 갱신되는 세그먼트 트리 (Lazy Propagation of Segment Tree) (0) | 2022.06.23 |
백준(BOJ) 세그먼트 트리 문제 풀이 모음 (Segment Tree) (5) | 2022.06.20 |
C++ 제곱근 분할법 알고리즘 (Sqaure-root Decomposition, 예제 풀이 포함) (0) | 2022.06.11 |