알고리즘/백준(BOJ) 문제풀이

[C++ 백준 풀이] 2042번 : 구간 합 구하기 (세그먼트 트리, Segment Tree)

restudy 2021. 12. 10. 00:34
반응형

이 포스트는 프로그래밍 문제 사이트 백준 Online Judge의 2042번 : '구간 합 구하기' 문제의 풀이 코드와 해설을 다루고 있습니다.

 

문제의 난이도는 Solved.ac 기준 Gold I에 해당하며, 이 문제를 풀기 위해 세그먼트 트리(Segment Tree)의 개념을 다룰 것입니다.

 

 

2042번 : 구간 합 구하기

 

N개의 수가 입력되고, 두 가지 종류의 쿼리를 처리하는 문제입니다.

쿼리 1은 b번째 숫자를 c로 바꾸는 처리이고, 쿼리 2는 b번째부터 c번째까지의 수의 합을 구하여 출력하는 처리입니다.

이 문제를 풀기 위해서는 단순히 구간 합을 미리 구해서 배열에 저장하는 방법으로는 시간 제한과 메모리 초과로 풀 수 없으므로 세그먼트 트리를 활용해야 합니다.

세그먼트 트리란 구간 합을 이분적으로 쪼개서 저장하고 있는 트리입니다.

특정 노드의 키를 수정하는데 걸리는 시간과 구간 합을 구하는데 걸리는 시간이 모두 O(log N)이므로 여러 개의 쿼리를 처리하기에도 적절합니다.

 

그럼 세그먼트 트리를 구현하고 설명하도록 하겠습니다.

 

 

↓ 위의 이미지와 동일한 풀이 코드는 아래의 접은 글에 정리되어 있습니다.

더보기
#include <cstdio>
#include <cmath>
#include <vector>
using namespace std;

vector<long long> Arr;
vector<long long> SegmentTree;
vector<pair<int, pair<int, long long>>> Query;


long long CreateSegmentTree(int Begin, int End, int Node) {
    if(Begin == End) return SegmentTree[Node] = Arr[Begin];
    long long LeftKey = CreateSegmentTree(Begin, (Begin+End)/2, Node*2);
    long long RightKey = CreateSegmentTree((Begin+End)/2+1, End, Node*2+1);
    SegmentTree[Node] = LeftKey + RightKey;
    return SegmentTree[Node];
}

void UpdateSegmentTree(int Begin, int End, int Node, int Index, long long Diff) {
    if(Index < Begin || Index > End) return;
    SegmentTree[Node] += Diff;
    if(Begin < End) {
        UpdateSegmentTree(Begin, (Begin+End)/2, Node*2, Index, Diff);
        UpdateSegmentTree((Begin+End)/2+1, End, Node*2+1, Index, Diff);
    }
}

long long SumOfSegmentTree(int Begin, int End, int Node, int Left, int Right) {
    if(Left > End || Right < Begin) return 0;
    if(Left <= Begin && Right >= End) return SegmentTree[Node];
    long long LeftKey = SumOfSegmentTree(Begin, (Begin+End)/2, Node*2, Left, Right);
    long long RightKey = SumOfSegmentTree((Begin+End)/2+1, End, Node*2+1, Left, Right);
    return LeftKey + RightKey;
}

int main() {
    int N, M, K;
    scanf("%d %d %d", &N, &M, &K);
    for(int i=0; i<N; i++) {
        long long data;
        scanf("%lld", &data);
        Arr.push_back(data);
    }
    for(int i=0; i<M+K; i++) {
        int a, b; long long c;
        scanf("%d %d %lld", &a, &b, &c);
        Query.push_back({a, {b, c}});
    }
    int TreeHeight = (int)ceil(log2(N));
    int TreeSize = (1 << (TreeHeight+1));
    SegmentTree.resize(TreeSize);
    CreateSegmentTree(0, N-1, 1);
    for(int i=0; i<Query.size(); i++) {
        if(Query[i].first == 1) {
            int Index = Query[i].second.first-1;
            long long Diff = Query[i].second.second - Arr[Index];
            Arr[Index] = Query[i].second.second;
            UpdateSegmentTree(0, N-1, 1, Index, Diff);
        }
        else printf("%lld\n", SumOfSegmentTree(0, N-1, 1, Query[i].second.first-1, Query[i].second.second-1));
    }
}

 

1. 전체적인 구조

먼저 Arr 벡터를 선언해서 데이터들을 저장하도록 합니다.

벡터로 저장하는 이유는 소모되는 메모리를 최소화하기 위해서입니다.

그 다음 Query를 저장하는데, 3개의 수가 입력되므로 이중 pair로 저장해주어야 합니다.

특히 c의 경우 long long 범위까지 필요하기 때문에 데이터 타입을 long long으로 선언해주어야 합니다.

CreateSegmentTree 함수로 세그먼트 트리를 먼저 만들어줍니다.

이후 쿼리를 하나씩 처리하면서 쿼리 1에 대해서는 UpdateSegmentTree로 트리를 갱신해줍니다.

쿼리 2에 대해서는 SumOfSegmentTree 함수를 이용해 구간 합을 불러와서 출력해줍니다.

 

2. CreateSegmentTree 함수

Begin과 End 주소를 받고, 전체적으로는 Key 값을 리턴해주도록 합니다. (Begin ~ End 구간의 합에 해당)

재귀적으로 Key의 합을 구하는 함수이기 때문에, 구간의 길이가 1 즉 Start = End이면 Node의 Key 값으로 그냥 Arr 값을 리턴해주고, 그게 아닐 때는 재귀적으로 구간을 반으로 나눠 그 합을 리턴하도록 설계합니다.

 

3. UpdateSegmentTree 함수

특정 Index의 값을 수정하기 위해서는 트리를 루트에서부터 따라내려가면서 모든 노드를 수정해주어야 합니다.

이 역시 구간을 계속해서 반으로 나누어가면서 Index가 Begin~End 사이 구간에 있는지 확인하고 아니라면 탐색을 종료합니다.

범위에 해당한다면 미리 계산한 Diff 값을 더해서 값을 갱신해줍니다. (Diff 값을 미리 구한 이유는 연산의 횟수를 최소화하기 위함)

 

그리고 함수 밖의 쿼리 1을 처리하는 부분에서 주의해야 할 점이 있는데, Arr[Index] 값 역시 수정을 해주어야 합니다. (이미지의 코드에서 57행에 해당)

그 이유는 첫 번째 쿼리 1에서 구간 1짜리 Update를 수행할 때는 문제 없지만, 두 번째 쿼리 1에서 구간 1짜리 Update를 수행할 때 Arr 값이 변경되지 않으면 잘못된 Diff 값이 계산되기 때문입니다.

따라서 이미지 코드의 57행과 같이 Arr 값을 갱신해주어야 합니다.

 

4. SumOfSegmentTree 함수

구간 합 역시 재귀적으로 절반씩 쪼개면서 구간에 해당하는 구간 합을 합쳐서 리턴해주면 됩니다.

위의 함수와 대체로 비슷한 형태이므로 간략하게만 서술하고 넘어가겠습니다.

 

 

 

위의 코드대로 제출을 해보면 236ms의 채점 시간으로 모든 테스트케이스를 통과할 수 있습니다.

이 문제는 제가 지금까지 백준에서 풀이한 문제들 중에 가장 구현이 어려웠던 것 같습니다. (특히 쿼리 1 수행할 때 Arr 값 갱신 안해준 것 때문에 꽤 많이 헤맸습니다.)

이는 다른 이유도 있겠지만, 무엇보다도 제가 세그먼트 트리를 처음 구현해보았기 때문이 큰 것 같습니다.

앞으로는 세그먼트 트리를 활용하는 문제들을 위주로 풀면서 숙련도를 높이도록 하겠습니다.

 

 

 

반응형