[알고리즘] 선택 알고리즘 - 최악의 경우 선형 시간 선택 알고리즘

2024. 2. 26. 16:29Algorithms

숭실대학교 컴퓨터학부의 알고리즘 수업을 들으며 정리한 내용입니다.
참고교재: 쉽게 배우는 알고리즘(문병로)
선택 알고리즘 
1.선형 시간 선택 알고리즘(QuickSelect)
2.최악의 경우 선형 시간 선택 알고리즘

 

최악의 경우 선형 시간 선택 알고리즘

앞서 주어진 리스트에서 k번째로 작은(또는 큰) 요소를 찾는 선택 알고리즘을 구현했습니다.

하지만 최악의 경우, 선택한 기준 원소가 항상 최솟값이나 최댓값이 되어, 분할이 매우 불균형하게 이루어지면 재귀 호출이 깊어져 Θ(n²) 시간이 걸립니다. 이를 해결하기 위해서 Median of Medians 알고리즘을 적용해 Θ(n)의 시간 복잡도를 보장할 수 있습니다.

 

Median of Medians 알고리즘 단계

  1. 그룹화:
    • 주어진 배열을 5개씩의 그룹으로 나눕니다. 마지막 그룹은 5개 미만의 원소를 가질 수 있습니다.
  2. 그룹의 중앙값 찾기:
    • 각 그룹 내의 원소들을 정렬하고, 정렬된 각 그룹의 중앙값(median)을 찾습니다.
    • 이 중앙값들을 모아서 새로운 배열을 만듭니다.
  3. 중앙값들의 중앙값 찾기:
    • 2단계에서 얻은 중앙값들의 배열에서 다시 중앙값을 찾습니다. 이 중앙값이 전체 배열의 "pivot" 역할을 합니다.
    • 중앙값들의 중앙값을 재귀적으로 찾습니다. (배열의 크기가 작다면, 정렬 후 중앙값을 직접 찾을 수도 있습니다.)
  4. 배열 분할:
    • 선택한 pivot을 기준으로 주어진 배열을 분할합니다.
    • pivot보다 작거나 같은 원소들은 왼쪽 부분 배열, pivot보다 큰 원소들은 오른쪽 부분 배열에 위치시킵니다.
  5. 재귀적 선택:
    • pivot의 위치를 기준으로 원하는 순위의 원소가 어느 부분 배열에 위치하는지 결정합니다.
    • 원하는 순위의 원소가 pivot과 같다면, pivot을 반환합니다.
    • 원하는 순위의 원소가 pivot보다 작다면, 왼쪽 부분 배열에서 재귀적으로 찾습니다.
    • 원하는 순위의 원소가 pivot보다 크다면, 오른쪽 부분 배열에서 재귀적으로 찾습니다.

시간 복잡도 분석

Median of Medians 알고리즘은 배열의 중앙값들을 이용해 배열을 분할함으로써 최악의 경우에도 선형 시간을 보장합니다. 분할 과정에서 중앙값의 중앙값을 사용하는 이유는 분할이 항상 일정한 비율로 이루어지도록 보장하기 위해서입니다.

  • 그룹화와 중앙값 찾기: 배열을 5개씩 그룹화하고 각 그룹의 중앙값을 찾는 데 O(n) 시간이 소요됩니다.
  • 중앙값들의 중앙값 찾기: 중앙값들의 배열에서 중앙값을 재귀적으로 찾는 데 T(n/5) 시간이 소요됩니다.
  • 배열 분할: pivot을 기준으로 배열을 분할하는 데 O(n) 시간이 소요됩니다.
  • 재귀 호출: 배열을 분할한 후, 적어도 30%의 원소가 제거되므로 남은 원소에 대해 T(7n/10) 시간이 소요됩니다.

따라서 전체 시간 복잡도는 다음과 같이 표현할 수 있습니다. 이 수식은 결국 선형 시간인 O(n)으로 귀결됩니다.

 

 

최악의 경우 선형 시간 선택 코드

다음코드는 일반적인 선택 알고리즘과 Median of Medians을 적용한 선택알고리즘의 수행시간을 비교하는 코드입니다.

#include <iostream>
#include <math.h>
#include <chrono>

using namespace std;

//swap method
void swap(int* a, int* b) {
    int temp = *a;
    *a = *b;
    *b = temp;
}

/*--------------------------------------------------------select----------------------------------------------------------------------*/

int partition(int arr[], int low, int high) {
    int pivot = arr[high];
    int i = low;
    for (int j = low; j < high; j++) {
        if (arr[j] < pivot) {
            int temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
            i++;
        }
    }
    swap(&arr[i],&arr[high]);
    return i;
}
int select(int arr[], int p, int r, int i) {
    if (p == r)
        return arr[p];

    int q = partition(arr, p, r);
    int k = q - p + 1;

    if (i < k)
        return select(arr, p, q - 1, i);
    else if (i == k)
        return arr[q];
    else
        return select(arr, q + 1, r, i - k);
}

/*-------------------------------------------------------------------heapSort methods-------------------------------------------------------------*/

void heapify(int arr[], int i, int n) {
    int smallest = i;
    int l = 2 * i;
    int r = 2 * i + 1;

    if (l <= n && arr[l] < arr[smallest])
        smallest = l;
    if (r <= n && arr[r] < arr[smallest])
        smallest = r;
    if (smallest != i) {
        swap(&arr[i], &arr[smallest]);
        heapify(arr, smallest, n);
    }
}

void buildHeap(int arr[], int n) {
    for (int i = n / 2; i >= 1; i--)
        heapify(arr, i, n);
}

void heapSort(int arr[], int n) {
    buildHeap(arr, n);
    for (int i = n; i >= 2; i--) {
        swap(&arr[1], &arr[i]);
        heapify(arr, 1, i - 1);
    }
}


/*-----------------------------------------------------------Linear select-----------------------------------------------------------------*/
int linearSelect(int arr[], int p, int r, int i) {
    //1
    if ((r - p + 1) <= 5) {
        return select(arr,p,r,i);
    }

    //2
    int numOfGroups = ceil((double)(r - p + 1) / 5);

    //3
    int *medianArr = new int[numOfGroups];
    for (int j = 0; j < numOfGroups; j++) {
        heapSort(arr + p + j * 5-1, min(5, r - p + 1 - j*5));
        medianArr[j] = arr[p + j * 5 + min(2, r - p + 1 - j*5 - 1)];
    }

    //4
    int M = linearSelect(medianArr, 0, numOfGroups - 1, numOfGroups / 2 + 1);
    delete[] medianArr;

    //5
    for (int j = p; j <= r; j++) {
        if (arr[j] == M) {
            swap(&arr[j], &arr[r]);
            break;
        }
    }
    int q = partition(arr, p, r);

    //6
    if (i == q - p + 1)
        return arr[q];
    else if (i < q - p + 1)
        return linearSelect(arr, p, q - 1, i);
    else
        return linearSelect(arr, q + 1, r, i - q + p - 1);
}

/*------------------------------------------------------test---------------------------------------------------*/
int main() {
    int n = 3000;
    int* arr1 = new int[n];
    int* arr2 = new int[n];

    srand(time(0));

    for (int i = 0; i < n; i++) {
        arr1[i] = rand()% 3000;
        arr2[i] = rand()% 3000;
    }

    for(int repetition=1; repetition<=10; repetition++){
        cout << "Repetition: " << repetition << endl;

        //test select
        clock_t start1 = clock();
        for(int i=0; i<repetition; i++){
            for (int k = 1; k <= 1500; k++) {
                int kthSmallest_select = select(arr1, 0, n - 1, k);
            }
        }
        clock_t end1 = clock();

        //test linearSelect
        clock_t start2 = clock();
        for(int i=0; i<repetition; i++){
            for (int k = 1; k <= 1500; k++) {
                int kthSmallest_select = linearSelect(arr2, 0, n - 1, k);
            }
        }
        clock_t end2 = clock();

        //result
        cout<<"Select Execution Time: "<<(double)(end1-start1)/(double)1000000<<endl;
        cout<<"Linear Select Execution Time: "<<(double)(end2-start2)/(double)1000000<<endl;
    }

    delete[] arr1;
    delete[] arr2;
    return 0;
}

실행결과

테스트 케이스가 많아질수록 Median of Medians가 적용된 알고리즘이 빠른것을 알 수 있습니다.

일반 선택 알고리즘
1
8.365
2
18.8
3
29.943
4
39.619
5
51.261
6
76.381
7
83.921
8
93.496
9
113.391
10
112.768 
최악의 경우 선형시간 선택 알고리즘
1
0.299
2
0.628
3
0.959
4
1.34
5
1.646
6
3.044
7
2.356
8
3.53
9
4.653
10
3.468