[알고리즘] 선택 알고리즘 - 선형 시간 선택 알고리즘(QuickSelect)

2024. 2. 26. 16:18Algorithms

숭실대학교 컴퓨터학부의 알고리즘 수업을 들으며 정리한 내용입니다.
참고교재: 쉽게 배우는 알고리즘(문병로)
선택 알고리즘 
1.선형 시간 선택 알고리즘(QuickSelect)
2.최악의 경우 선형 시간 선택 알고리즘

 

선택 알고리즘이란

 

 

선택 알고리즘은 주어진 리스트에서 k번째로 작은(또는 큰) 요소를 찾는 알고리즘을 의미합니다. 이 알고리즘은 다양한 방법으로 구현될 수 있습니다.

 

선형 시간 선택 알고리즘은 최악의 경우에도 O(n)의 시간 복잡도를 가지는 알고리즘입니다. 이 알고리즘의 대표적인 예는 QuickSelect입니다. QuickSelect는 QuickSort 알고리즘을 기반으로 하며, 분할-정복 방식을 사용합니다. 피벗을 선정하고 이를 기준으로 리스트를 두 부분으로 나눈 후, k가 어느 부분에 있는지에 따라 해당 부분만 재귀적으로 탐색합니다. 이 방법은 평균적으로 O(n)의 시간 복잡도를 가집니다.

 

그러나 QuickSelect의 경우 최악의 경우에는 O(n^2)의 시간 복잡도를 가질 수 있습니다. 이를 방지하기 위해 Median of Medians라는 알고리즘을 사용할 수 있는데, 이 알고리즘은 피벗을 더 효과적으로 선택함으로써 항상 선형 시간 내에 k번째 요소를 찾을 수 있게 합니다. 이렇게 두 알고리즘을 합친 것을 '선형 시간 복잡도를 가진 선택 알고리즘'이라고 부르며, 이 알고리즘은 최악의 경우에도 O(n)의 시간 복잡도를 보장합니다.

 

QuickSelect 코드

int partition(int arr[], int left, int right){
    int pivot = arr[right];
    int i = (left - 1);

    for(int j = left; j < right; j++){
        if(arr[j] < pivot){
            i++;
            swap(arr[i],arr[j]);
        }
    }

    swap(arr[i+1], arr[right]);
    return i + 1;  // 피벗의 위치를 반환
}


int select(int arr[], int left, int right, int i){
    //원소가 하나일 경우
    if(left == right) return arr[left];

    int pivotIndex = partition(arr, left, right);

    if(i == pivotIndex)
        return arr[pivotIndex];
    else if(i < pivotIndex)
        return select(arr, left, pivotIndex - 1, i);
    else
        return select(arr, pivotIndex + 1, right, i);
}

 

QuickSelect 과정

1. partition 함수: 배열 `arr`의 `left`부터 `right`까지의 부분 배열에서 피벗을 설정하고, 피벗보다 작은 요소들은 모두 피벗의 왼쪽으로, 피벗보다 큰 요소들은 모두 피벗의 오른쪽으로 이동시키는 작업을 수행합니다. 이 함수는 피벗의 최종 위치를 반환합니다.

2. select 함수: 이 함수는 QuickSelect 알고리즘을 구현한 것입니다. `partition` 함수를 사용해서 피벗을 설정하고, 배열을 두 부분으로 분할합니다. 그 다음 `i`가 피벗의 위치보다 작은지, 큰지에 따라 해당 부분 배열에서 재귀적으로 k번째 작은 요소를 찾습니다. 만약 `i`가 피벗의 위치와 같다면, 피벗이 바로 k번째 작은 요소이므로 피벗을 반환합니다. 따라서 이 코드는 주어진 배열 `arr`에서 `i`번째로 작은 요소를 빠르게 찾는데 사용됩니다.

테스트 코드

#include <iostream>
#include <cstdlib>
#define SIZE 10
using namespace std;

int partition(int arr[], int left, int right){
    int pivot = arr[right];
    int i = (left - 1);

    for(int j = left; j < right; j++){
        if(arr[j] < pivot){
            i++;
            swap(arr[i],arr[j]);
        }
    }

    swap(arr[i+1], arr[right]);
    return i + 1;  // 피벗의 위치를 반환
}


int select(int arr[], int left, int right, int i){
    //원소가 하나일 경우
    if(left == right) return arr[left];

    int pivotIndex = partition(arr, left, right);

    if(i == pivotIndex)
        return arr[pivotIndex];
    else if(i < pivotIndex)
        return select(arr, left, pivotIndex - 1, i);
    else
        return select(arr, pivotIndex + 1, right, i);
}

void printArr(int arr[], int len){
    for(int i=0; i<len; i++){
        cout<< arr[i] <<" ";
    }
    cout << endl;
}

int main() {
    srand(time(0));
    int arr[SIZE];
    for(int i=0; i<SIZE; i++){
        arr[i] = rand() % 100; // 0부터 99까지의 랜덤한 정수 생성
    }
    printArr(arr, SIZE);

    int k = 5; // 찾고자 하는 k번째 작은 요소의 위치
    int kthSmallest = select(arr, 0, SIZE - 1, k - 1);
    cout << k << "th smallest element is " << kthSmallest << endl;

    return 0;
}

 

 

테스트 결과

68 42 98 92 75 22 80 28 16 27 
5th smallest element is 42