본문 바로가기

알고리즘/백준

세그먼트 트리(Segment Tree) & [백준]2042 구간합(Python)

개념 및 특징 

 

  • 완전 이진 트리를 기반으로, 각 노드가 구간의 정보(구간합, 구간의 최솟값 등에 대한)를 저장하고 있는 자료구조
  • 리프 노드들은 길이가 1인 각각의 구간을 소유
  • 부모 노드는 자신의 자식 노드들의 구간의 합 소유 => 루트 노드는 전체 구간을 포함 
  • 모든 구간은 연속적

세그먼트 트리

 

 

시간 복잡도 및 공간 복잡도 

 

원하는 구간합을 구하고자 한다면,

 

시간 복잡도: O(logN) => 장점

공간 복잡도: 메모리가 많이 필요 => 구간합을 모두 저장해야하므로 => 단점 

 

구현 

 

1. Top → Down : 재귀 사용

2. Bottom → Up : For 반복문 사용, 코드 구현 간편 

 

  • 초기화 
def init(start, end, node):
    if start == end:    # 리프 노드
        tree[node] = nums[start]
        return

    mid = (start + end) // 2
    init(start, mid, node*2)    # 왼쪽 자식 노드의 구간합
    init(mid+1, end, node*2+1)  # 오른쪽 자식 노드의 구간합

    tree[node] = tree[node*2] + tree[node*2 + 1]    # 현재 노드 = 왼쪽 자식 노드의 구간합 + 오른쪽 자식 노드의 구간합

 

  • 업데이트 
def update(L, R, node, idx, val):   # idx: 바꿀 값의 index    value: 바꿀 값
    if L == R == idx:   # 리프 노드
        tree[node] = val
        return

    if idx < L or R < idx:  # 현재 노드의 구간에 idx가 포함되지 않으면 종료
        return

    mid = (L + R) // 2
    update(L, mid, node*2, idx, val)
    update(mid+1, R, node*2 + 1, idx, val)

    tree[node] = tree[node*2] + tree[node*2 + 1]    # 업데이트된 자식 노드들을 더해서 현재 노드의 값에 저장

업데이트

 

  • 구간합 
# L: 구하고자 하는 구간합의 왼쪽 구간
# R: 구하고자 하는 구간합의 오른쪽 구간
# node: 현재 노드
# nodeLeft: 노드의 왼쪽 구간
# nodeRight: 노드의 오른쪽 구간
def sum(L, R, node, nodeLeft, nodeRight):
    # 원하는 구간합의 구간 내에 현재 노드의 구간이 포함되지 않는다면 현재 노드의 값 불필요
    if R < nodeLeft or nodeRight < L:   
        return 0

    # 원하는 구간합의 구간 내에 현재 노드의 구간이 포함된다면 현재 노드의 값을 반환
    if L <= nodeLeft and nodeRight <= R:
        return tree[node]
    
    # 구간이 겹치는 경우에는 자식 노드에 대해 sum 함수 호출 
    mid = (nodeLeft + nodeRight) // 2
    return sum(L, R, node*2, nodeLeft, mid) + sum(L, R, node*2+1, mid+1, nodeRight)

구간합

 

관련문항 

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

코드
import sys
input = sys.stdin.readline

def init(start, end, node):
    if start == end:    # 리프 노드
        tree[node] = nums[start]
        return

    mid = (start + end) // 2
    init(start, mid, node*2)    # 왼쪽 자식 노드의 구간합
    init(mid+1, end, node*2+1)  # 오른쪽 자식 노드의 구간합

    tree[node] = tree[node*2] + tree[node*2 + 1]    # 현재 노드 = 왼쪽 자식 노드의 구간합 + 오른쪽 자식 노드의 구간합


def update(L, R, node, idx, val):   # idx: 바꿀 값의 index    value: 바꿀 값
    if L == R == idx:   # 리프 노드
        tree[node] = val
        return

    if idx < L or R < idx:  # 현재 노드의 구간에 idx가 포함되지 않으면 종료
        return

    mid = (L + R) // 2
    update(L, mid, node*2, idx, val)
    update(mid+1, R, node*2 + 1, idx, val)

    tree[node] = tree[node*2] + tree[node*2 + 1]    # 업데이트된 자식 노드들을 더해서 현재 노드의 값에 저장


# L: 구하고자 하는 구간합의 왼쪽 구간
# R: 구하고자 하는 구간합의 오른쪽 구간
# node: 현재 노드
# nodeLeft: 노드의 왼쪽 구간
# nodeRight: 노드의 오른쪽 구간
def sum(L, R, node, nodeLeft, nodeRight):
    # 원하는 구간합의 구간 내에 현재 노드의 구간이 포함되지 않는다면 현재 노드의 값 불필요
    if R < nodeLeft or nodeRight < L:
        return 0

    # 원하는 구간합의 구간 내에 현재 노드의 구간이 포함된다면 현재 노드의 값을 반환
    if L <= nodeLeft and nodeRight <= R:
        return tree[node]

    # 구간이 겹치는 경우에는 자식 노드에 대해 sum 함수 호출
    mid = (nodeLeft + nodeRight) // 2
    return sum(L, R, node*2, nodeLeft, mid) + sum(L, R, node*2+1, mid+1, nodeRight)


N, M, K = map(int, input().split())
nums = []   # 숫자 저장 리스트
tree = [0 for _ in range(4*N)]  # 세그먼트 트리 저장 리스트

for _ in range(N):
    num = int(input())
    nums.append(num)

init(0, N-1, 1) # 세그먼트 트리 초기화

for _ in range(M+K):
    a, b, c = map(int, input().split())

    if a == 1:  # b와 c를 바꾸고
        b -= 1
        update(0, N-1, 1, b, c)

    else:   # b부터 c까지 수의 합
        b -= 1
        c -= 1
        print(sum(b, c, 1, 0, N-1))

 

참고

https://one10004.tistory.com/241

'알고리즘 > 백준' 카테고리의 다른 글

[백준]16953 A->B(Python)  (0) 2024.03.31
[백준]11437 LCA(Python)  (0) 2024.02.14
[백준]16927 배열 돌리기2(Python)  (0) 2024.02.04
[백준]2096 내려가기(Python)  (1) 2024.02.03
[백준]3665 최종순위(Python)  (1) 2024.01.30