알고리즘/알고리즘 공부 내용 정리

금광 세그먼트 트리 : 2차원 배열의 최대 구간 합 알고리즘 (설명, 예제 코드)

restudy 2022. 7. 10. 19:35
반응형

2차원 배열에서 특정 직사각형 구간을 잡아 얻을 수 있는 구간의 최대 합을 가장 빠른 시간에 구할 수 있는 알고리즘은 무엇일까요?

특정한 점화식을 사용하는 세그먼트 트리를 이용하면 2차원 공간에서의 최대 구간합을 O(N^2 log N)에 구할 수 있다는 것이 잘 알려져 있으며, 이 알고리즘을 사용하는 문제가 정보 올림피아드에 '금광'이라는 이름의 문제로 출제되어 많은 사람들이 이를 금광 세그라고 부릅니다.

 

이 포스트에서는 금광 세그에 대해 공부한 것을 간단하게 설명해보고, 이를 구현해볼 것입니다.

 

 

백준 BOJ 10167번 : 금광

N개의 점이 주어지고, 각 점에 대해 x, y 좌표와 가중치 w가 주어질 때, 특정 직사각형 형태의 구간을 잡아 얻을 수 있는 최대 구간 합을 구하는 문제입니다.

풀이 코드가 시간 안에 통과하도록 풀이하기 위해서는 금광 세그를 구현해야 함을 위에서 이미 언급했으므로, 여기에서는 바로 설명을 다루겠습니다.

 

세그먼트 트리를 이용하면 구간의 합과, 구간의 최댓값O(log N) 시간에 업데이트하거나 구할 수 있습니다.

그런데 이 구간의 합과 몇 가지 변수를 추가하여 점화식을 만들면, 구간의 최대 합에 대한 식을 만들어낼 수 있습니다.

다음 설명에서 구간의 왼쪽, 오른쪽이라는 것은 세그먼트 트리의 노드가 담당하는 구간의 양단을 말하는 것입니다.

 

먼저 사용할 변수 4개는 다음과 같습니다.

- sum : 구간 합

- lsum : 구간의 맨 왼쪽부터 시작하여 잡은 구간의 최대 합

- rsum : 구간의 맨 오른쪽부터 시작하여 (왼쪽으로) 잡은 구간의 최대 합

- maxsum : 구간 합의 최댓값

 

그러면 다음과 같이 노드 a, b를 자식으로 가지는 상위 노드의 4개의 변수에 대한 점화식을 만들 수 있습니다.

- sum = a.sum + b.sum

- lsum = max(a.lsum, a.sum + b.lsum)
- rsum = max(b.rsum, b.sum + a.rsum)

- maxsum = max({a.maxsum, b.maxsum, a.rsum + b.lsum})

 

 

 

점화식에 대한 간단한 설명을 그림으로 첨부합니다.

 

따라서 위와 같이 4개의 변수를 모두 관리할 수 있는 세그먼트 트리를 구성할 수 있음을 확인하였으므로 이 문제를 세그먼트 트리를 이용하여 풀이할 수 있음을 알게 되었습니다.

 

그렇다면 이제 구현을 해봅시다.

위에서 설명한 것은 핵심 아이디어이고, 추가적으로는 좌표 압축과, 하나의 축을 기준으로 스위핑을 하여 나머지 축에 대해 세그먼트 트리에 업데이트 시키는 부분 역시 구현해주어야 합니다.

이에 대한 배경 지식이 없다면 이 문제를 풀어보기 전에 해당 부분에 대한 다른 자료를 참고하시는 것을 권장합니다.

 

 

#include <bits/stdc++.h>
#define int long long
using namespace std;

struct p { int x, y, w; };
vector<p> v;

struct node { int lsum, rsum, sum, maxsum; };
vector<node> tree;

node mer(node a, node b) {
    int lsum = max(a.lsum, a.sum + b.lsum);
    int rsum = max(b.rsum, b.sum + a.rsum);
    int sum = a.sum + b.sum;
    int maxsum = max({a.maxsum, b.maxsum, a.rsum + b.lsum});

    return {lsum, rsum, sum, maxsum};
}

void upd(int n, int b, int e, int idx, int val) {
    if(idx < b || e < idx) return;

    tree[n].lsum += val;
    tree[n].rsum += val;
    tree[n].sum += val;
    tree[n].maxsum += val;

    if(b == e) return;

    upd(n*2, b, (b+e)/2, idx, val);
    upd(n*2 + 1, (b+e)/2 + 1, e, idx, val);

    tree[n] = mer(tree[n*2], tree[n*2 + 1]);
}

node query(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return {INT_MIN, INT_MIN, INT_MIN, INT_MIN};
    if(l <= b && e <= r) return tree[n];

    node ln = query(n*2, b, (b+e)/2, l, r);
    node rn = query(n*2 + 1, (b+e)/2 + 1, e, l, r);

    return mer(ln, rn);
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;

    v.resize(N+1);
    vector<pair<int, int>> vx(N+1), vy(N+1);

    for(int i=1; i<=N; i++) {
        cin >> v[i].x >> v[i].y >> v[i].w;

        vx[i] = {v[i].x, i};
        vy[i] = {v[i].y, i};
    }

    sort(vx.begin()+1, vx.end());
    sort(vy.begin()+1, vy.end());

    int xcnt = 1, ycnt = 1;

    for(int i=1; i<=N; i++) {
        if(i > 1 && vx[i].first > vx[i-1].first) xcnt++;

        v[vx[i].second].x = xcnt;
    }
    for(int i=1; i<=N; i++) {
        if(i > 1 && vy[i].first > vy[i-1].first) ycnt++;

        v[vy[i].second].y = ycnt;
    }

    vector<vector<int>> u(N+1, vector<int>(N+1));
    vector<vector<pair<int, int>>> uy(N+1);

    for(int i=1; i<=N; i++) {
        u[v[i].x][v[i].y] = v[i].w;
        uy[v[i].y].push_back({v[i].x, v[i].w});
    }

    int ans = INT_MIN;
    for(int i=1; i<=ycnt; i++) {
        tree.clear();
        tree.resize((N+1)*4);

        for(int j=i; j<=ycnt; j++) {
            for(int k=0; k<uy[j].size(); k++)
                upd(1, 1, N, uy[j][k].first, uy[j][k].second);

            ans = max(ans, query(1, 1, N, 1, N).maxsum);
        }
    }

    cout << ans << "\n";
}

 

위와 같이 풀이를 작성해줄 수 있습니다.

하나의 노드가 가지는 4개의 변수 모두 "합"에 관한 변수이므로, 트리를 업데이트 해줄 때 자식 노드로 내려가면서 거치는 노드마다 4개의 변수 모두에 val을 계속 더해주기만 하면 됩니다.

 

 


+ 다음과 같은 부분 점수 풀이도 존재합니다.

정답 풀이는 아니므로 참고만 하시기 바랍니다.

 

wrong sol1 ) 구간 합 (O(N^6)) (시간 초과)

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct p { int x, y, w; };
vector<p> v;

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;

    v.resize(N+1);
    vector<pair<int, int>> vx(N+1), vy(N+1);

    for(int i=1; i<=N; i++) {
        cin >> v[i].x >> v[i].y >> v[i].w;

        vx[i] = {v[i].x, i};
        vy[i] = {v[i].y, i};
    }

    sort(vx.begin()+1, vx.end());
    sort(vy.begin()+1, vy.end());

    int xcnt = 1, ycnt = 1;

    for(int i=1; i<=N; i++) {
        if(i > 1 && vx[i].first > vx[i-1].first) xcnt++;

        v[vx[i].second].x = xcnt;
    }
    for(int i=1; i<=N; i++) {
        if(i > 1 && vy[i].first > vy[i-1].first) ycnt++;

        v[vy[i].second].y = ycnt;
    }

    vector<vector<int>> u(N+1, vector<int>(N+1));
    for(int i=1; i<=N; i++)
        u[v[i].x][v[i].y] = v[i].w;

    int ans = INT_MIN;
    for(int i=1; i<=N; i++)
        for(int j=i; j<=N; j++)
            for(int k=1; k<=N; k++)
                for(int l=k; l<=N; l++) {
                    int sum = 0;

                    for(int m=i; m<=j; m++)
                        for(int n=k; n<=l; n++) sum += u[m][n];

                    ans = max(ans, sum);
                }

    cout << ans << "\n";
}

 

wrong sol 2) 누적 합을 이용한 구간 합 (O(N^4)) (11점)

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct p { int x, y, w; };
vector<p> v;

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;

    v.resize(N+1);
    vector<pair<int, int>> vx(N+1), vy(N+1);

    for(int i=1; i<=N; i++) {
        cin >> v[i].x >> v[i].y >> v[i].w;

        vx[i] = {v[i].x, i};
        vy[i] = {v[i].y, i};
    }

    sort(vx.begin()+1, vx.end());
    sort(vy.begin()+1, vy.end());

    int xcnt = 1, ycnt = 1;

    for(int i=1; i<=N; i++) {
        if(i > 1 && vx[i].first > vx[i-1].first) xcnt++;

        v[vx[i].second].x = xcnt;
    }
    for(int i=1; i<=N; i++) {
        if(i > 1 && vy[i].first > vy[i-1].first) ycnt++;

        v[vy[i].second].y = ycnt;
    }

    vector<vector<int>> u(N+1, vector<int>(N+1));
    for(int i=1; i<=N; i++)
        u[v[i].x][v[i].y] = v[i].w;

    vector<vector<int>> us(N+1, vector<int>(N+1));
    for(int i=1; i<=N; i++)
        for(int j=1; j<=N; j++)
            us[i][j] = us[i-1][j] + us[i][j-1] - us[i-1][j-1] + u[i][j];

    int ans = INT_MIN;
    for(int i=1; i<=N; i++)
        for(int j=i; j<=N; j++)
            for(int k=1; k<=N; k++)
                for(int l=k; l<=N; l++) {
                    int sum = us[j][l] - us[i-1][l] - us[j][k-1] + us[i-1][k-1];

                    ans = max(ans, sum);
                }

    cout << ans << "\n";
}

 

wrong ans 3) DP를 활용한 누적합 (O(N^3)) (채점은 50%까지 가는데 WA 받음)

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct p { int x, y, w; };
vector<p> v;

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;

    v.resize(N+1);
    vector<pair<int, int>> vx(N+1), vy(N+1);

    for(int i=1; i<=N; i++) {
        cin >> v[i].x >> v[i].y >> v[i].w;

        vx[i] = {v[i].x, i};
        vy[i] = {v[i].y, i};
    }

    sort(vx.begin()+1, vx.end());
    sort(vy.begin()+1, vy.end());

    int xcnt = 1, ycnt = 1;

    for(int i=1; i<=N; i++) {
        if(i > 1 && vx[i].first > vx[i-1].first) xcnt++;

        v[vx[i].second].x = xcnt;
    }
    for(int i=1; i<=N; i++) {
        if(i > 1 && vy[i].first > vy[i-1].first) ycnt++;

        v[vy[i].second].y = ycnt;
    }

    vector<vector<int>> u(N+1, vector<int>(N+1));
    for(int i=1; i<=N; i++)
        u[v[i].x][v[i].y] = v[i].w;

    int ans = INT_MIN;
    for(int i=1; i<=N; i++) {
        vector<int> ux(N+1), Max(N+1);

        for(int j=i; j<=N; j++) {
            for(int k=1; k<=N; k++) ux[k] += u[j][k];

            Max[1] = ux[1];

            for(int k=2; k<=N; k++) {
                Max[k] = max(Max[k-1], (int)0) + ux[k];

                ans = max(ans, Max[k]);
            }
        }
    }

    cout << ans << "\n";
}

 

 

 

반응형