[EX] Merge Sort

Radix Sort와 더불어 익스 시험에서 가장 많이 쓰는 정렬이 아닐까 싶다. 시간 제한이 빡빡하지 않고, N은 작은데 M이 클 때 사용한다. (여기서 N은 배열의 원소 개수, M은 배열 원소가 갖는 값의 범위를 말한다.)


핵심 아이디어#

절반씩 나눠서 정렬한 다음, 그 결과를 합치면 된다.

조금 더 구체적으로 설명하면, 배열 arr[s,e)를 정렬한다고 하자. (여기서 arr[s,e)는 arr[s], arr[s+1], …, arr[e-1]을 의미한다.)

그러면, arr[s,(s+e)/2]가 이미 정렬되어 있고, arr[(s+e)/2,e)가 이미 정렬되어 있다 할 때, 이를 정렬하여 arr[s,e)를 만들려면 아래과 같이 구현하면 된다.

int m = (s+e)/2;
//여기서 이미 arr[s,m), arr[m,e)는 각각 정렬되어 있다고 하자.
//정렬된 두 배열을 합쳐 정렬된 배열 arr[s,e)를 구하는 방법은 아래와 같다.
int i=s, j=m, k=s;                                 
while(i<m && j<e) {                                
    if(arr[i]<arr[j]) tmp[k++] = arr[i++];         
    else tmp[k++] = arr[j++];                      
}                                                  
while(i<m) tmp[k++] = arr[i++];                    
while(j<e) tmp[k++] = arr[j++];                    
for(i=s; i<e; i++) arr[i]=tmp[i];                  

정렬 함수를 msort(int *arr, int s, int e)라고 하면, arr[s,m), arr[m,e)가 각각 정렬된 상태는 msort(arr,s,m), msort(arr,m,e)로 나타낼 수 있다. 이를 코드로 나타내면 아래와 같다. 이는 곧 msort(arr,s,e)를 구하는 재귀함수 코드가 된다.

int m = (s+e)/2;
msort(arr,s,m); //이 부분이 추가되었다.
msort(arr,m,e); //이 부분이 추가되었다.

int i=s, j=m, k=s;                                 
while(i<m && j<e) {                                
    if(arr[i]<arr[j]) tmp[k++] = arr[i++];         
    else tmp[k++] = arr[j++];                      
}                                                  
while(i<m) tmp[k++] = arr[i++];                    
while(j<e) tmp[k++] = arr[j++];                    
for(i=s; i<e; i++) arr[i]=tmp[i];                  

재귀함수 코드를 적기 위해선 재귀 종료 조건이 있어야 한다. msort() 함수는 내부적으로 arr을 계속해서 반씩 쪼개고 합치면서 정렬을 하므로, 이 쪼개는 행위를 더 이상 할 수 없는 지점에서 재귀호출하지 않도록 해야한다. 따라서, s+1==e일 때 재귀호출을 종료 하면 된다.

완성된 코드는 아래와 같다

int tmp[1000]; //이 부분이 추가되었다.                                                             
void msort(int *arr, int s, int e) { //이 부분이 추가되었다.                                         
    if(s+1>=e) return; //이 부분이 추가되었다.
    int m = (s+e)/2;                                                         
    msort(arr,s,m);
    msort(arr,m,e);                                                                             
    int i=s, j=m, k=s;                                                       
    while(i<m && j<e) {                                                      
        if(arr[i]<arr[j]) tmp[k++] = arr[i++];                               
        else tmp[k++] = arr[j++];                                            
    }                                                                        
    while(i<m) tmp[k++] = arr[i++];                                          
    while(j<e) tmp[k++] = arr[j++];                                          
    for(i=s; i<e; i++) arr[i]=tmp[i];                                        
}                                                                             

최종 코드#

마지막으로 시험장에서 잘 쓸 수 있게 구조화하면 아래와 같다.

나는 시험장에서 항상 이런 형태로 Merge Sort를 구현해서 쓴다.

template <typename T, int N>
struct MergeSort{
    T tmp[N];
    void Run(T *v, int s, int e) {
        if(s+1>=e) return;
        int m = (s+e)>>1;

        Run(v,s,m); Run(v,m,e);

        int i=s, j=m, k=s;
        while(i<m && j<e) {
            if(v[i]<v[j]) tmp[k++] = v[i++];
            else tmp[k++] = v[j++];
        }
        while(i<m) tmp[k++] = v[i++];
        while(j<e) tmp[k++] = v[j++];
        for(i=s; i<e; i++) v[i]=tmp[i];

    }
};

시간 복잡도#

msort() 함수가 배열 arr[N]을 원소 하나만 남을 때까지 반으로 계속 나누므로, 그 깊이는 logN이 된다. 그리고 각 깊이마다 배열 N개를 모두 들여다 봐야하므로 시간복잡도는 O(NlogN)이 된다.