[EX] Merge Sort
Radix Sort와 더불어 익스 시험에서 가장 많이 쓰는 정렬이 아닐까 싶다. 시간 제한이 빡빡하지 않고, N은 작은데 M이 클 때 사용한다. (여기서 N은 배열의 원소 개수, M은 배열 원소가 갖는 값의 범위를 말한다.)
핵심 아이디어#
절반씩 나눠서 정렬한 다음, 그 결과를 합치면 된다.
조금 더 구체적으로 설명하면, 배열 arr[s,e)를 정렬한다고 하자. (여기서 arr[s,e)는 arr[s], arr[s+1], …, arr[e-1]을 의미한다.)
그러면, arr[s,(s+e)/2]가 이미 정렬되어 있고, arr[(s+e)/2,e)가 이미 정렬되어 있다 할 때, 이를 정렬하여 arr[s,e)를 만들려면 아래과 같이 구현하면 된다.
int m = (s+e)/2;
//여기서 이미 arr[s,m), arr[m,e)는 각각 정렬되어 있다고 하자.
//정렬된 두 배열을 합쳐 정렬된 배열 arr[s,e)를 구하는 방법은 아래와 같다.
int i=s, j=m, k=s;
while(i<m && j<e) {
if(arr[i]<arr[j]) tmp[k++] = arr[i++];
else tmp[k++] = arr[j++];
}
while(i<m) tmp[k++] = arr[i++];
while(j<e) tmp[k++] = arr[j++];
for(i=s; i<e; i++) arr[i]=tmp[i];
정렬 함수를 msort(int *arr, int s, int e)라고 하면, arr[s,m), arr[m,e)가 각각 정렬된 상태는 msort(arr,s,m), msort(arr,m,e)로 나타낼 수 있다. 이를 코드로 나타내면 아래와 같다. 이는 곧 msort(arr,s,e)를 구하는 재귀함수 코드가 된다.
int m = (s+e)/2;
msort(arr,s,m); //이 부분이 추가되었다.
msort(arr,m,e); //이 부분이 추가되었다.
int i=s, j=m, k=s;
while(i<m && j<e) {
if(arr[i]<arr[j]) tmp[k++] = arr[i++];
else tmp[k++] = arr[j++];
}
while(i<m) tmp[k++] = arr[i++];
while(j<e) tmp[k++] = arr[j++];
for(i=s; i<e; i++) arr[i]=tmp[i];
재귀함수 코드를 적기 위해선 재귀 종료 조건이 있어야 한다. msort() 함수는 내부적으로 arr을 계속해서 반씩 쪼개고 합치면서 정렬을 하므로, 이 쪼개는 행위를 더 이상 할 수 없는 지점에서 재귀호출하지 않도록 해야한다. 따라서, s+1==e일 때 재귀호출을 종료 하면 된다.
완성된 코드는 아래와 같다
int tmp[1000]; //이 부분이 추가되었다.
void msort(int *arr, int s, int e) { //이 부분이 추가되었다.
if(s+1>=e) return; //이 부분이 추가되었다.
int m = (s+e)/2;
msort(arr,s,m);
msort(arr,m,e);
int i=s, j=m, k=s;
while(i<m && j<e) {
if(arr[i]<arr[j]) tmp[k++] = arr[i++];
else tmp[k++] = arr[j++];
}
while(i<m) tmp[k++] = arr[i++];
while(j<e) tmp[k++] = arr[j++];
for(i=s; i<e; i++) arr[i]=tmp[i];
}
최종 코드#
마지막으로 시험장에서 잘 쓸 수 있게 구조화하면 아래와 같다.
나는 시험장에서 항상 이런 형태로 Merge Sort를 구현해서 쓴다.
template <typename T, int N>
struct MergeSort{
T tmp[N];
void Run(T *v, int s, int e) {
if(s+1>=e) return;
int m = (s+e)>>1;
Run(v,s,m); Run(v,m,e);
int i=s, j=m, k=s;
while(i<m && j<e) {
if(v[i]<v[j]) tmp[k++] = v[i++];
else tmp[k++] = v[j++];
}
while(i<m) tmp[k++] = v[i++];
while(j<e) tmp[k++] = v[j++];
for(i=s; i<e; i++) v[i]=tmp[i];
}
};시간 복잡도#
msort() 함수가 배열 arr[N]을 원소 하나만 남을 때까지 반으로 계속 나누므로, 그 깊이는 logN이 된다. 그리고 각 깊이마다 배열 N개를 모두 들여다 봐야하므로 시간복잡도는 O(NlogN)이 된다.