알고리즘/알고리즘 공부 내용 정리

백준(BOJ) 세그먼트 트리 문제 풀이 모음 (Segment Tree)

restudy 2022. 6. 20. 00:04
반응형

이 포스트에서는 세그먼트 트리를 다루며, 특히 백준 Online Judge(BOJ)의 세그먼트 트리에 관련된 문제들을 풀이해보도록 하겠습니다.

 

 

 

문제를 풀기 전에, 세그먼트 트리(Segment Tree)는 언제 사용하면 될까요?

세그먼트 트리는 값의 갱신구간의 대표값을 구하는 것을 O(log N)의 시간에 수행할 수 있습니다. (구간의 합, 곱, 최댓값, 최솟값, 특정 값보다 큰 값의 개수)

따라서 구간의 어떤 값을 여러 번 구해야하는 경우에 유용하게 사용할 수 있습니다. (특히 쿼리 문제)

 

이제 문제들을 풀어봅시다.

각 문제의 풀이 코드는 접은 글 안에 정리되어 있습니다.

 

 

백준 BOJ 2042번 : 구간 합 구하기

수열에서 값의 갱신과 구간 합의 계산을 여러 번 수행해야 하는 문제입니다.

세그먼트 트리에 구간의 합을 저장함으로써 해결할 수 있습니다.

n번 노드가 특정 구간의 합을 저장한다고 할 때, 왼쪽 자식 노드인 n*2번 노드에는 왼쪽 구간의 합을, 오른쪽 자식 노드인 n*2 + 1번 노드에는 오른쪽 구간의 합을 저장하면 됩니다. (물론 tree를 initialize하는 과정에서는 노드의 값을 재귀적으로 구해줄 수 있습니다.)

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

vector<int> v, u;

int init(int n, int b, int e) {
    if(b == e) return u[n] = v[b];

    int l = init(n*2, b, (b+e)/2);
    int r = init(n*2 + 1, (b+e)/2 + 1, e);

    return u[n] = l + r;
}

void update(int n, int b, int e, int idx, int diff) {
    if(idx < b || e < idx) return;

    u[n] += diff;

    if(b < e) {
        update(n*2, b, (b+e)/2, idx, diff);
        update(n*2 + 1, (b+e)/2 + 1, e, idx, diff);
    }
}

int f(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return 0;
    if(l <= b && e <= r) return u[n];

    return f(n*2, b, (b+e)/2, l, r) + f(n*2 + 1, (b+e)/2 + 1, e, l, r);
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N, M, K; cin >> N >> M >> K;

    v.resize(N+1);
    for(int i=1; i<=N; i++) cin >> v[i];

    u.resize(N*4);
    init(1, 1, N);

    M += K;
    while(M--) {
        int Q, a, b; cin >> Q >> a >> b;

        if(Q == 1) {
            int diff = b - v[a];
            v[a] = b;
            update(1, 1, N, a, diff);
        }
        else if(Q == 2) cout << f(1, 1, N, a, b) << "\n";
    }
}

 

 

백준 BOJ 2357번 : 최솟값과 최댓값

다음과 같이 특정 구간의 최솟값이나 최댓값 역시 세그먼트 트리로 해결할 수 있습니다.

단순히 트리를 2개 선언하여 최솟값 트리, 최댓값 트리로 활용해주면 됩니다.

코드가 길어지기 때문에 변수를 짧게 선언하는 것을 추천합니다.

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

vector<int> v, u, w;

int init_u(int n, int b, int e) {
    if(b == e) return u[n] = v[b];

    int lv = init_u(n*2, b, (b+e)/2);
    int rv = init_u(n*2 + 1, (b+e)/2 + 1, e);

    return u[n] = min(lv, rv);
}

int init_w(int n, int b, int e) {
    if(b == e) return w[n] = v[b];

    int lv = init_w(n*2, b, (b+e)/2);
    int rv = init_w(n*2 + 1, (b+e)/2 + 1, e);

    return w[n] = max(lv, rv);
}

int f(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return INT_MAX;
    if(l <= b && e <= r) return u[n];

    int lv = f(n*2, b, (b+e)/2, l, r);
    int rv = f(n*2 + 1, (b+e)/2 + 1, e, l, r);

    return min(lv, rv);
}

int g(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return INT_MIN;
    if(l <= b && e <= r) return w[n];

    int lv = g(n*2, b, (b+e)/2, l, r);
    int rv = g(n*2 + 1, (b+e)/2 + 1, e, l, r);

    return max(lv, rv);
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N, M; cin >> N >> M;

    v.resize(N+1);
    for(int i=1; i<=N; i++) cin >> v[i];

    u.resize(N*4);
    init_u(1, 1, N);

    w.resize(N*4);
    init_w(1, 1, N);

    while(M--) {
        int a, b; cin >> a >> b;
        cout << f(1, 1, N, a, b) << " " << g(1, 1, N, a, b) << "\n";
    }
}

 

 

백준 BOJ 11505번 : 구간 곱 구하기

구간 합 구하기와 비슷한 문제입니다.

데이터의 범위와 모듈로 연산에만 주의하여 코드를 작성해주면 됩니다.

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

vector<int> v, u;
int mod = 1e9 + 7;

int init(int n, int b, int e) {
    if(b == e) return u[n] = v[b];

    int lv = init(n*2, b, (b+e)/2);
    int rv = init(n*2 + 1, (b+e)/2 + 1, e);

    return u[n] = (lv * rv) % mod;
}

int upd(int n, int b, int e, int idx, int val) {
    if(idx < b || e < idx) return u[n];
    if(b == e) return u[n] = val;

    int lv = upd(n*2, b, (b+e)/2, idx, val);
    int rv = upd(n*2 + 1, (b+e)/2 + 1, e, idx, val);

    return u[n] = (lv * rv) % mod;
}

int mul(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return 1;
    if(l <= b && e <= r) return u[n];

    int lv = mul(n*2, b, (b+e)/2, l, r);
    int rv = mul(n*2 + 1, (b+e)/2 + 1, e, l, r);

    return (lv * rv) % mod;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N, M, K; cin >> N >> M >> K;

    v.resize(N+1);
    for(int i=1; i<=N; i++) cin >> v[i];

    u.resize(N*4);
    init(1, 1, N);

    M += K;
    while(M--) {
        int Q, a, b; cin >> Q >> a >> b;

        if(Q == 1) upd(1, 1, N, a, b);
        else if(Q == 2) cout << mul(1, 1, N, a, b) << "\n";
    }
}

 

 

백준 BOJ 6549번 : 히스토그램에서 가장 큰 직사각형

세그먼트 트리를 잘 응용해야하는 문제입니다.

어떤 구간에서 가장 큰 직사각형의 넓이는 반드시 다음의 3가지 중에 하나입니다.

 

1. 구간의 최솟값 * 구간의 길이

2. 구간의 왼쪽 끝 ~ (구간의 최솟값 주소 - 1) 구간에서 가장 큰 직사각형의 넓이

3. (구간의 최솟값 주소 + 1) ~ 구간의 오른쪽 끝 구간에서 가장 큰 직사각형의 넓이

 

따라서 구간이 주어지면 먼저 1번을 구해주고, 2번과 3번을 재귀적으로 수행하며 O(log N) 시간에 넓이의 최댓값을 구해주면 됩니다.

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

int N;
vector<int> v, u;

int init(int n, int b, int e) {
    if(b == e) return u[n] = b;

    int lv = init(n*2, b, (b+e)/2);
    int rv = init(n*2 + 1, (b+e)/2 + 1, e);

    if(v[lv] < v[rv]) return u[n] = lv;
    else return u[n] = rv;
}

int g(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return -1;
    if(l <= b && e <= r) return u[n];

    int lv = g(n*2, b, (b+e)/2, l, r);
    int rv = g(n*2 + 1, (b+e)/2 + 1, e, l, r);

    if(lv < 0) return rv;
    if(rv < 0) return lv;

    if(v[lv] < v[rv]) return lv;
    else return rv;
}

int f(int l, int r) {
    int idx = g(1, 1, N, l, r);
    int ret = (r-l+1)*v[idx];

    if(idx+1 <= r) ret = max(ret, f(idx+1, r));
    if(idx-1 >= l) ret = max(ret, f(l, idx-1));

    return ret;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    while(true) {
        cin >> N;
        if(N == 0) break;

        v.resize(N+1);
        for(int i=1; i<=N; i++) cin >> v[i];

        u.resize(N*4);
        init(1, 1, N);

        cout << f(1, N) << "\n";
    }
}

 

 

백준 BOJ 1517번 : 버블 소트 (Inversion Counting Problem)

버블 소트에서 값들의 swap이 발생하는 경우는, 왼쪽 원소의 값이 오른쪽 원소의 값보다 큰 경우입니다.

따라서 각 원소들에 대해 자신의 오른쪽에 있는 더 작은 원소들의 개수를 구하여 합해주면 됩니다.

 

더보기
#include <bits/stdc++.h>
#define int long long
using namespace std;

vector<pair<int, int>> v;
vector<int> u;

void upd(int n, int b, int e, int idx) {
    if(b == e) {
        u[n] = 1;
        return;
    }

    if(idx <= (b+e)/2) upd(n*2, b, (b+e)/2, idx);
    else upd(n*2 + 1, (b+e)/2 + 1, e, idx);

    u[n] = u[n*2] + u[n*2 + 1];
}

int cnt(int n, int b, int e, int l, int r) {
    if(r < b || e < l) return 0;
    if(l <= b && e <= r) return u[n];

    int lv = cnt(n*2, b, (b+e)/2, l, r);
    int rv = cnt(n*2 + 1, (b+e)/2 + 1, e, l, r);

    return lv + rv;
}

main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL), cout.tie(NULL);

    int N; cin >> N;

    v.resize(N+1);
    for(int i=1; i<=N; i++) {
        cin >> v[i].first;
        v[i].second = i;
    }

    sort(v.begin()+1, v.end());

    u.resize(N*4);
    int ans = 0;

    for(int i=1; i<=N; i++) {
        ans += cnt(1, 1, N, v[i].second+1, N);
        upd(1, 1, N, v[i].second);
    }

    cout << ans << "\n";
}

 

 

 

반응형