코딩테스트/알고리즘

8-2. 바이너리 인덱스 트리(BIT) (= 펜윅 트리 )

초코chip 2024. 4. 2. 20:20

개념

  • 언제 사용?: 데이터 업데이트가 가능한 상황에서 구간 합를 구해야할 경우에

  • 정의: 2진법 인덱스 구조를 활용해 구간 합 문제를 효과적으로 해결할 수 있는 자료구조

 

  • 특정 숫자 K의 비트 중 0이 아닌 마지막 비트 찾는 방법: K & - K

 

 

구현 방법

BIT 구조 만들기

0이 아닌 마지막 비트 = 내가 저장하고 있는 값들의 개수

예:

k = 16의 마지막 비트는 16 > 1~16까지의 구간합을 저장한다는 의미

k = 7의 마지막 비트는 1 > 자기 자신의 값만 저장한다는 의미

'

코드

// 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
private static int n, m, k;

// 전체 데이터의 개수는 최대 1,000,000개
private static long[] arr = new long[n+1];
private static long[] tree = new long[n+1]; //BIT 트리 배열(누적합)

 

BIT - 특정 값 업데이트

개념

  • 방법: 0이 아닌 마지막 비트만큼 더하면서 구간들의 값을 변경
  • 예시:
    • k = 3인 구간의 값을 업데이트 한다
      1. 3의 마지막 비트는 1 > 3+1 = 4의 구간합 변경
      2. 4의 마지막 비트는 4 > 4+4 = 8의 구간합 변경
      3. 8의 마지막 비트는 8 > 8+8 = 16의 구간합 변경
    • 총 4번의 업데이트가 발생
  • 시간 복잡도: O(logN)

 

코드

// i번째 수를 dif(변화량)만큼 더하는 함수
private static void update(int i, long dif) {
    while(i <= n) {
        tree[i] += dif;
        i += (i & -i);
    }
}

 

 

BIT - 누적 합(Prefix Sum) 도출

개념

  • 방법 - 1~K까지의 합 구하기: 0이 아닌 마지막 비트만큼 뺴면서 구간들의 합 계산
  • 예: 
    • 1~11(k=11)까지의 합 구하기
      1. 11의 마지막 비트는 1 > 11-1 = 10
      2. 10의 마지막 비트는 2 > 10-2 = 8
      3. 8의 마지막 비트는 8 > 8-8 = 0
    • 총 4번의 과정 진행
  • 시간 복잡도: O(logN)

코드

// 1~i번째 수까지의 누적 합을 계산하는 함수
private static long prefixSum(int i) {
    long result = 0;
    while(i > 0) {
        result += tree[i];
        // 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i);
    }
    return result;
}

 

BIT - 구간합 도출

개념

  • 방법 - start ~ end까지의 구간합: prefixSum(end) - prefixSum(start-1)

 

코드

// start부터 end까지의 구간 합을 계산하는 함수
private static long intervalSum(int start, int end) {
    return prefixSum(end) - prefixSum(start - 1);
}

 

메인 함수

public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        
        n = scanner.nextInt();
        m = scanner.nextInt();
        k = scanner.nextInt();
        
        //1. BIT 트리 생성
        arr = new long[n+1];
        tree = new long[n+1];
        
        //2. 초기 값 입력 + BIT 트리 초기화
        for(int i = 1; i <= n; i++) {
            long x = scanner.nextLong();
            arr[i] = x;
            update(i, x); //초기 변화량은 단순 x
        }
        
        int count = 0;
        while(count++ < m + k) {
            int op = scanner.nextInt();
            // 업데이트(update) 연산인 경우
            if(op == 1) {
                int index = scanner.nextInt();
                long value = scanner.nextLong();
                update(index, value - arr[index]); // 바뀐 크기(dif)만큼 적용
                arr[index] = value; // i번째 수를 value로 업데이트
            }
            // 구간 합(interval sum) 연산인 경우
            else {
                int start = scanner.nextInt();
                int end = scanner.nextInt();
                System.out.println(intervalSum(start, end));
            }
        }
        scanner.close();
    }

 

전체 코드

import java.util.Scanner;

public class Main {
    // 전체 데이터의 개수는 최대 1,000,000개
    private static long[] arr = new long[1000001];
    private static long[] tree = new long[1000001];
    // 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
    private static int n, m, k;

    // i번째 수까지의 누적 합을 계산하는 함수
    private static long prefixSum(int i) {
        long result = 0;
        while(i > 0) {
            result += tree[i];
            // 0이 아닌 마지막 비트만큼 빼가면서 이동
            i -= (i & -i);
        }
        return result;
    } 

    // i번째 수를 dif만큼 더하는 함수
    private static void update(int i, long dif) {
        while(i <= n) {
            tree[i] += dif;
            i += (i & -i);
        }
    }

    // start부터 end까지의 구간 합을 계산하는 함수
    private static long intervalSum(int start, int end) {
        return prefixSum(end) - prefixSum(start - 1);
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        m = scanner.nextInt();
        k = scanner.nextInt();
        
        for(int i = 1; i <= n; i++) {
            long x = scanner.nextLong();
            arr[i] = x;
            update(i, x);
        }
        
        int count = 0;
        while(count++ < m + k) {
            int op = scanner.nextInt();
            // 업데이트(update) 연산인 경우
            if(op == 1) {
                int index = scanner.nextInt();
                long value = scanner.nextLong();
                update(index, value - arr[index]); // 바뀐 크기(dif)만큼 적용
                arr[index] = value; // i번째 수를 value로 업데이트
            }
            // 구간 합(interval sum) 연산인 경우
            else {
                int start = scanner.nextInt();
                int end = scanner.nextInt();
                System.out.println(intervalSum(start, end));
            }
        }
        scanner.close();
    }
}