알고리즘/알고리즘 공부 내용 정리

12월 알고리즘 공부 : 퍼시스턴트 세그먼트 트리, 머지 소트 트리, 오일러 회로 등

restudy 2022. 12. 14. 22:03
반응형

백준 BOJ 11012번 : Egg

문제 난이도 : Platinum II

알고리즘 분류 : 퍼시스턴트 세그먼트 트리

 

2차원 좌표계에서 점들의 좌표가 주어질 때, 직사각형 영역이 쿼리로 주어지면 해당 영역에 점이 몇 개가 있는지를 구하는 문제이다. (0 ≤ x, y ≤ 10^5, Q ≤ 50,000)

 

1차원 좌표계였다면 일반적인 세그먼트 트리로 해결이 가능한 것은 물론이고 값의 갱신이 없으므로 누적 합으로도 쉽게 계산이 가능하다.

2차원 좌표계 + 작은 범위였다면 2차원 배열의 누적 합으로 쉽게 계산이 가능하다.

 

이 문제에서는 세그먼트 트리 또는 2차원 배열을 선언하면 메모리 초과가 발생하게 되는데 이를 어떻게 해결할 수 있을까?

이러한 문제는 퍼시스턴트 세그먼트 트리(Persistent Segment Tree, PST)로 풀이할 수 있다.

 

퍼시스턴트 세그먼트 트리는 노드의 갱신 기록을 저장하고 있는 트리이다.

예를 들어 0번 트리에서 어떤 노드가 갱신되면, 1번 트리를 만들고 갱신되는 노드들만 1번 트리에 노드를 만들고 0번 트리와 연결시킨다.

이런 식으로 해나가면 하나의 새로운 트리에 log N개 정도의 노드만 새로 갱신하면 되므로, M개의 갱신에 대해서도 M log N개의 노드만 저장하면 되므로 메모리를 효율적으로 아낄 수 있다.

 

대신 구현이 조금 더 복잡한데, 아래의 코드를 참고하자.

이 문제에서는 x값에 하나의 트리가 대응되게 하였으며, 각 x값에 해당하는 y값들을 x번 트리의 y번 노드에 대응되도록 하여 데이터를 저장하게 구현하였다.

이후는 누적 합을 계산하는 것과 같은 원리로 구할 수 있다.

주의할 점은 좌표가 0 "이상" 10^5 "이하"로 들어올 수 있으므로 이 처리를 잘해주어야 한다. (하나의 트리에서 값을 10^5 + 1개 이상 담당하도록 해야함)

 

참고로 퍼시스턴트 세그먼트 트리는 인덱스를 이용하여 구현하는 방법이 있고 포인터를 이용하여 구현하는 방법이 있다.

백준 BOJ 11012 Egg 문제의 풀이를 구글링해보면 대부분 인덱스 기반으로 구현되어있을 것인데, 아래는 포인터를 이용하여 구현하는 방법이다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

const int MAX = 1e5 + 1;

struct node {
    node *l, *r; int x;
    node() { l = r = NULL, x = 0; }
};

node *add(node *cur, int b, int e, int idx, int val) {
    if(cur == NULL) cur = new node();

    if(idx < b || e < idx) return cur;

    if(b == e) {
        node *ret = new node();

        ret->x = cur->x + 1;

        return ret;
    }

    node *ret = new node();

    node *ln = add(cur->l, b, (b+e)/2, idx, val);
    node *rn = add(cur->r, (b+e)/2 + 1, e, idx, val);

    ret->l = ln, ret->r = rn, ret->x = ln->x + rn->x;

    return ret;
}

int sum(node *cur, int b, int e, int l, int r) {
    if(cur == NULL) return 0;

    if(r < b || e < l) return 0;
    if(l <= b && e <= r) return cur->x;

    return sum(cur->l, b, (b+e)/2, l, r) + sum(cur->r, (b+e)/2 + 1, e, l, r);
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int T; cin >> T;

    while(T--) {
        int N, M; cin >> N >> M;

        vector<vector<int>> v(MAX+1);

        for(int i=0; i<N; i++) {
            int x, y; cin >> x >> y;

            v[x+1].push_back(y+1);
        }

        node *u[MAX+1];
        u[0] = new node();

        for(int i=1; i<=MAX; i++) {
            u[i] = new node();

            u[i]->l = u[i-1]->l, u[i]->r = u[i-1]->r, u[i]->x = u[i-1]->x;

            for(int j=0; j<v[i].size(); j++)
                u[i] = add(u[i], 1, MAX, v[i][j], 1);
        }

        int ans = 0;

        while(M--) {
            int a, b, c, d; cin >> a >> b >> c >> d;

            ans += sum(u[b+1], 1, MAX, c+1, d+1)
                   - sum(u[a], 1, MAX, c+1, d+1);
        }

        cout << ans << "\n";
    }
}

 

 

백준 BOJ 13537번 : 수열과 쿼리 1

문제 난이도 : Platinum III

알고리즘 분류 : 머지 소트 트리

 

길이 10만 이하의 수열이 주어지고, 10만 개 이하의 쿼리에 대해 구간 [i, j]가 주어지면 해당 구간에서 k보다 큰 원소의 개수를 구하는 문제이다.

 

세그먼트 트리와 동일하게 각 구간을 담당하는 노드가 있고, 각 노드에 담당 구간 원소들이 정렬된 벡터를 가지고 있다면 log 시간에 해결이 가능할 것이다.

그리고 이렇게 구현을 하는 경우 트리의 높이는 log N 정도이고, 하나의 층에 총 N개의 원소가 있으므로 공간복잡도 역시 O(N log N)으로 메모리 초과를 발생시키지 않는다.

 

이것이 머지 소트 트리(Merge Sort Tree)이다.

구현이 문제인데, 이것은 merge 함수를 이용하여 생각보다 간단하게 구현이 가능하다.

자세한 구현은 아래의 코드를 참고하면 도움이 될 것이다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct mergeSortTree {
    vector<vector<int>> v;

    void init(int n, int b, int e, vector<int> &u) {
        if(b == e) {
            v[n].push_back(u[b-1]);
            return;
        }

        init(n*2, b, (b+e)/2, u);
        init(n*2+1, (b+e)/2+1, e, u);

        v[n].resize(v[n*2].size() + v[n*2+1].size());
        merge(v[n*2].begin(), v[n*2].end(), v[n*2+1].begin(), v[n*2+1].end(), v[n].begin());
    }

    int gt(int n, int b, int e, int l, int r, int x) {
        if(r < b || e < l) return 0;

        if(l <= b && e <= r)
            return v[n].end() - upper_bound(v[n].begin(), v[n].end(), x);

        return gt(n*2, b, (b+e)/2, l, r, x) + gt(n*2+1, (b+e)/2+1, e, l, r, x);
    }
};

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    mergeSortTree f;
    f.v.resize(1<<18);

    int N; cin >> N;

    vector<int> v(N);
    for(int i=0; i<N; i++) cin >> v[i];

    f.init(1, 1, N, v);

    int M; cin >> M;

    while(M--) {
        int a, b, c; cin >> a >> b >> c;

        int ans = f.gt(1, 1, N, a, b, c);

        cout << ans << "\n";
    }
}

 

 

백준 BOJ 13544번 : 수열과 쿼리 3

문제 난이도 : Platinum III

알고리즘 분류 : 머지 소트 트리

 

위의 문제와 거의 동일하다.

코드 역시 거의 비슷하다. 아래에 첨부해두었다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct mergeSortTree {
    vector<vector<int>> v;

    void init(int n, int b, int e, vector<int> &u) {
        if(b == e) {
            v[n].push_back(u[b-1]);
            return;
        }

        init(n*2, b, (b+e)/2, u);
        init(n*2+1, (b+e)/2+1, e, u);

        v[n].resize(v[n*2].size() + v[n*2+1].size());
        merge(v[n*2].begin(), v[n*2].end(), v[n*2+1].begin(), v[n*2+1].end(), v[n].begin());
    }

    int gt(int n, int b, int e, int l, int r, int x) {
        if(r < b || e < l) return 0;

        if(l <= b && e <= r)
            return v[n].end() - upper_bound(v[n].begin(), v[n].end(), x);

        return gt(n*2, b, (b+e)/2, l, r, x) + gt(n*2+1, (b+e)/2+1, e, l, r, x);
    }
};

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    mergeSortTree f;
    f.v.resize(1<<18);

    int N; cin >> N;

    vector<int> v(N);
    for(int i=0; i<N; i++) cin >> v[i];

    f.init(1, 1, N, v);

    int M; cin >> M;

    int ans = 0;

    while(M--) {
        int a, b, c; cin >> a >> b >> c;

        a ^= ans, b ^= ans, c ^= ans;

        ans = f.gt(1, 1, N, a, b, c);

        cout << ans << "\n";
    }
}

 

 

백준 BOJ 7469번 : K번째 수

문제 난이도 : Platinum II

알고리즘 분류 : 머지 소트 트리

 

역시 쿼리 문제인데, 주어진 구간의 원소들을 정렬했을 때 K번째로 작은 수를 구하는 쿼리를 처리하는 문제이다.

머지 소트 트리로 해결이 가능하다.

 

방법은 이분 탐색을 활용하는 것이다.

적당한 x를 잡아서 x 이하인 수의 개수를 구하고, 이 x 이하인 수가 k개인 가장 작은 x를 이분 탐색으로 구해주면 된다.

 

구현은 아래에 있으며, ord 함수를 중점적으로 보면 될 것이다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct mergeSortTree {
    vector<vector<int>> v;

    void init(int n, int b, int e, vector<int> &u) {
        if(b == e) {
            v[n].push_back(u[b-1]);
            return;
        }

        init(n*2, b, (b+e)/2, u);
        init(n*2+1, (b+e)/2+1, e, u);

        v[n].resize(v[n*2].size() + v[n*2+1].size());
        merge(v[n*2].begin(), v[n*2].end(), v[n*2+1].begin(), v[n*2+1].end(), v[n].begin());
    }

    int ord(int n, int b, int e, int l, int r, int x) {
        if(r < b || e < l) return 0;

        if(l <= b && e <= r)
            return upper_bound(v[n].begin(), v[n].end(), x) - v[n].begin();

        return ord(n*2, b, (b+e)/2, l, r, x) + ord(n*2+1, (b+e)/2+1, e, l, r, x);
    }
};

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    mergeSortTree f;
    f.v.resize(1<<18);

    int N, M; cin >> N >> M;

    vector<int> v(N);
    for(int i=0; i<N; i++) cin >> v[i];

    f.init(1, 1, N, v);

    while(M--) {
        int a, b, c; cin >> a >> b >> c;

        int l = -1e9, r = 1e9, ans = 1e9;

        while(l <= r) {
            int m = (l + r) / 2;

            int x = f.ord(1, 1, N, a, b, m);

            if(x >= c) {
                ans = min(ans, m);
                r = m - 1;
            }
            else l = m + 1;
        }

        cout << ans << "\n";
    }
}

 

 

백준 BOJ 13310번 : 먼 별

문제 난이도 : Diamond V

알고리즘 분류 : 회전하는 캘리퍼스, 삼분 탐색

 

3만 개 이하의 2차원 좌표 상의 점과 하루 이동하는 거리가 주어지고, 10^7 이하의 일 수가 주어질 때, 주어진 기간 내에서 가장 먼 두 별의 거리가 가장 가까운 날짜를 구하고, 그 때의 거리를 구하는 문제이다.

 

우리는 우선 특정 날짜에서 별의 분포에 대해 가장 먼 두 별의 거리를 구하는 방법을 생각해보아야 한다.

이것은 볼록 껍질 (Convex Hull) + 회전하는 캘리퍼스 (Rotating Calipers) 알고리즘으로 구현할 수 있다.

 

이제 날짜에 따른 별들의 거리 분포에 대해 생각해보자.

여기선 엄밀한 증명은 하지 않겠지만, 직관적으로 보았을 때 별들의 최대 거리가 시간이 지남에 따라 증가하거나, 감소하거나, 또는 감소하다가 증가하는 세 가지 경우가 전부이다.

이렇게 최대 1개의 극값을 갖는 경우 삼분 탐색으로 답을 찾을 수 있다.

 

이 블로그에 컨벡스 헐, 회전하는 캘리퍼스, 삼분 탐색 알고리즘을 이전에 모두 정리해두었으므로 여기에서는 풀이 코드만 첨부한다.

두 삼분점에서의 가장 먼 두 별의 거리를 가지고 범위를 적절히 좁혀나간 뒤, 좁혀진 범위 내에서의 최소 거리를 구해주면 된다.

 

참고로 삼분 탐색 문제를 많이 안 풀어봤으면 부등호 범위가 애매할 수 있다. (나도 오랜만에 풀어서 잠시 헷갈렸다.)

양단 좌표가 l, r이고 삼분점이 각각 m1, m2라고 할 때 이 문제에서는 f(m1) == f(m2)인 경우 r = m2로 범위를 좁혀야 되는데, 그 이유는 가장 먼 두 별의 거리가 최소인 날짜가 여러 개일 경우 가장 빠른 촬영일을 출력하면 된다고 하였으므로 오른쪽 끝을 왼쪽으로 붙여주면 된다.

만약 이 처리를 하지 않는다면 삼분 탐색이 끝나도 f(x) 값이 동일한 날이 매우 많아 l ~ r 범위가 커진다면 시간 초과가 발생한다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct P { int x, y; };

P operator-(P a, P b) {
    P c;
    c.x = a.x - b.x;
    c.y = a.y - b.y;
    return c;
}

int N, M;
vector<P> v, u, w;

int ccw(P a, P b, P c) {
    return a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y);
}

bool cmp(P &a, P &b) {
    int x = ccw(v[0], a, b);

    if(x != 0) return x > 0;
    else if(a.y != b.y) return a.y < b.y;
    else return a.x < b.x;
}

int f(int m) {
    v.clear(); v.resize(N);

    for(int i=0; i<N; i++) {
        v[i].x = u[i].x + w[i].x * m;
        v[i].y = u[i].y + w[i].y * m;
    }

    for(int i=1; i<N; i++)
            if(v[i].y < v[0].y || (v[i].y == v[0].y && v[i].x < v[0].x)) swap(v[i], v[0]);

    sort(v.begin()+1, v.end(), cmp);

    stack<P> s;

    s.push(v[0]);
    s.push(v[1]);

    for(int i=2; i<N; i++) {
        while(s.size() >= 2) {
            P a = s.top(); s.pop();
            P b = s.top();

            if(ccw(b, a, v[i]) > 0) {
                s.push(a);
                break;
            }
        }
        s.push(v[i]);
    }

    vector<P> u(s.size());
    while(!s.empty()) {
        u[s.size()-1] = s.top();
        s.pop();
    }

    int l = 0, r = 0;
    for(int i=0; i<u.size(); i++) {
        if(u[i].x < u[l].x) l = i;
        if(u[i].x > u[r].x) r = i;
    }

    int ret = pow(u[l].x - u[r].x, 2) + pow(u[l].y - u[r].y, 2);
    P o = {0, 0};

    for(int i=0; i<u.size(); i++) {
        int nl = (l+1) % u.size();
        int nr = (r+1) % u.size();

        if(ccw(o, u[nl] - u[l], u[r] - u[nr]) > 0) l = nl;
        else r = nr;

        ret = max(ret, (u[l].x - u[r].x) * (u[l].x - u[r].x)
                        + (u[l].y - u[r].y) * (u[l].y - u[r].y));
    }

    return ret;
}

signed main() {
    ios_base::sync_with_stdio(0), cin.tie(0);

    cin >> N >> M;

    u.resize(N), w.resize(N);

    for(int i=0; i<N; i++)
        cin >> u[i].x >> u[i].y >> w[i].x >> w[i].y;

    int l = 0, r = M;

    while(l+3 <= r) {
        int m1 = (l*2 + r) / 3;
        int m2 = (l + r*2) / 3;

        if(f(m1) > f(m2)) l = m1;
        else r = m2;
    }

    int dis = LLONG_MAX, day;

    for(int i=l; i<=r; i++) {
        int val = f(i);

        if(val < dis) {
            dis = val;
            day = i;
        }
    }

    cout << day << "\n";
    cout << dis << "\n";
}

 

 

백준 BOJ 1199번 : 오일러 회로

문제 난이도 : Platinum IV

알고리즘 분류 : 오일러 경로

 

그래프에서의 N x N 크기의 인접 행렬이 주어질 때, 오일러 회로대로 방문하는 정점들을 순서대로 출력하는 문제이다.

 

오일러 회로그래프에서 모든 간선을 지나고 다시 출발점으로 돌아오는 "회로"를 말한다.

 

오일러 회로 자체를 구현하는 것은 다음과 같은 아이디어로 구현한다.

회로를 따라 이동하면 다시 원래 자리로 돌아오므로 여러 개의 회로를 쪼개어 찾은 뒤 이 경로들을 합쳐주면 된다.

이것은 DFS를 재귀적으로 구현하면 된다.

 

문제는 이 문제에 한해서는 이렇게 구현된 Hierholzer 알고리즘은 O(VE)의 시간복잡도를 가지는데, 여기에서는 O(E log V) 이하로 줄여야 통과가 가능하다.

따라서 나의 경우 세그먼트 트리를 활용하여 꼭짓점들을 log 시간에 인접 리스트에서 탐색하여 추가하거나 제거할 수 있도록 구현해주었다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct segmentTree {
    vector<int> v;

    void upd(int n, int b, int e, int idx, int val) {
        if(b == e) {
            v[n] += val;
            return;
        }

        if(idx <= (b+e)/2) upd(n*2, b, (b+e)/2, idx, val);
        else upd(n*2 + 1, (b+e)/2 + 1, e, idx, val);

        v[n] = v[n*2] + v[n*2 + 1];
    }

    int sum(int n, int b, int e, int l, int r) {
        if(r < b || e < l) return 0;
        if(l <= b && e <= r) return v[n];

        return sum(n*2, b, (b+e)/2, l, r)
                + sum(n*2 + 1, (b+e)/2 + 1, e, l, r);
    }

    int kth(int n, int b, int e, int ran) {
        if(b == e) return b;

        if(ran <= v[n*2]) return kth(n*2, b, (b+e)/2, ran);
        else return kth(n*2+1, (b+e)/2+1, e, ran-v[n*2]);
    }
};

struct EulerianPath {
    segmentTree adj[1001];

    int N, start = 0;
    stack<int> s;

    void init() {
        cin >> N;

        for(int i=1; i<=N; i++) {
            adj[i].v.resize(N*4);

            for(int j=1; j<=N; j++) {
                int x; cin >> x;

                adj[i].upd(1, 1, N, j, x);
            }
        }
    }

    bool exist() {
        for(int i=1; i<=N; i++) {
            int cnt = adj[i].sum(1, 1, N, 1, N);

            if(cnt % 2 == 1) return false;

            if(start == 0 && cnt > 0) start = i;
        }

        return true;
    }

    void dfs(int x) {
        while(adj[x].sum(1, 1, N, 1, N) > 0) {
            int y = adj[x].kth(1, 1, N, 1);

            adj[x].upd(1, 1, N, y, -1);
            adj[y].upd(1, 1, N, x, -1);

            dfs(y);
        }

        s.push(x);
    }
};

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    EulerianPath f; f.init();

    if(!f.exist()) {
        cout << -1 << "\n";
        return 0;
    }

    f.dfs(f.start);

    while(!f.s.empty()) {
        cout << f.s.top() << " ";
        f.s.pop();
    }
    cout << "\n";
}

 

 

백준 BOJ 9463번 : 순열 그래프

문제 난이도 : Platinum V

알고리즘 분류 : 세그먼트 트리

 

 

 

 

이 문제는 기본적인 세그먼트 트리 문제인데 워낙 웰노운이라 여기에 간단히만 정리한다.

위와 같이 2개의 순열이 주어지고, 위 아래에서 같은 수끼리 연결했을 때 교차점의 수를 구하는 문제이다.

 

핵심은 넘버링 부분인데, 웰노운이긴 하지만 깔끔하기도 하고 기억해둘만 해서 아래에 정리한다.

 

먼저 윗줄에서 i번째로 입력받은 x에 대해 v[x] = i로 저장해준다.

그 다음 아랫줄에서 i번째로 입력받은 x에 대해 u[i] = v[x]로 저장해준다.

그러면 u[i]는 아랫줄의 i번째 수가 있는 윗줄의 인덱스가 되므로, 이제 세그먼트 트리로 매우 쉽게 처리해줄 수 있다.

 

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

struct segmentTree {
    vector<int> v;

    void upd(int n, int b, int e, int idx, int val) {
        if(b == e) {
            v[n] += val;
            return;
        }

        if(idx <= (b+e)/2) upd(n*2, b, (b+e)/2, idx, val);
        else upd(n*2 + 1, (b+e)/2 + 1, e, idx, val);

        v[n] = v[n*2] + v[n*2 + 1];
    }

    int sum(int n, int b, int e, int l, int r) {
        if(r < b || e < l) return 0;
        if(l <= b && e <= r) return v[n];

        return sum(n*2, b, (b+e)/2, l, r)
                + sum(n*2 + 1, (b+e)/2 + 1, e, l, r);
    }
};

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int T; cin >> T;

    while(T--) {
        int N; cin >> N;

        segmentTree f; f.v.resize(N*4);

        vector<int> v(N+1), u(N+1);

        for(int i=1; i<=N; i++) {
            int x; cin >> x;

            v[x] = i;
        }

        for(int i=1; i<=N; i++) {
            int x; cin >> x;

            u[i] = v[x];
        }

        int ans = 0;

        for(int i=1; i<=N; i++) {
            ans += (u[i] - 1) - f.sum(1, 1, N, 1, u[i]);

            f.upd(1, 1, N, u[i], 1);
        }

        cout << ans << "\n";
    }
}

 

 

 

반응형