Algorithm

[백준/파이썬] 2751 수 정렬하기2 - 합병 정렬 구현하기

마크투비 2022. 9. 25. 21:02

 2751번 수 정렬하기2


 문제

N개의 수가 주어졌을 때, 이를 오름차순으로 정렬하는 프로그램을 작성하시오.

 입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)이 주어진다. 둘째 줄부터 N개의 줄에는 수가 주어진다. 이 수는 절댓값이 1,000,000보다 작거나 같은 정수이다. 수는 중복되지 않는다.

 출력

첫째 줄부터 N개의 줄에 오름차순으로 정렬한 결과를 한 줄에 하나씩 출력한다.

 

💡 구현 아이디어


 

이 문제는 당연히 sort() 함수를 사용하여 해결하는 것이 가장 간단하고, 그게 맞는 방법이다. 이 풀이는 그저 합병 정렬 알고리즘을 연습하기 위함이다. 

 

🔸 합병 정렬 (Merge sort)

합병 정렬은 분할 정복(divide and conquer) 방식으로 설계되었다. 한 번에 해결할 수 없는 문제를 작은 문제로 분할하여 해결하는 알고리즘이다. 주로 재귀 함수로 구현하고, 다음과 같이 3단계로 이루어진다.

 

1. Divide: 문제 분할 - 하나의 배열을 반으로 나눈다

2. Conquer: 분할된 작은 문제 해결 - 나뉜 배열들을 각각 정렬한다

3. Combine: 해결된 문제들을 다시 합침 - 다시 하나의 배열로 합친다

 

출처: 위키백과

 

합병 정렬 알고리즘의 시간 복잡도는 O(n logn)이다. 버블 정렬의 경우는 O(n^2)이다. 파이썬의 sort() 함수의 시간 복잡도도 궁금해서 찾아본 결과 파이썬의 정렬 라이브러리는 최악의 경우에도 시간 복잡도 O(n logn)을 보장한다고 한다. 파이썬의 주요 함수 및 연산자들의 시간 복잡도를 정리한 페이지가 있다. 

 

실전에서 정렬 문제를 만났다면 직접 정렬 알고리즘을 구현하려는 생각은 하지 말고 빠르게 sort() 함수를 사용하는 것이 현명한 방법이다. 아래 소스코드에서 2가지 방법 모두 확인해볼 수 있다. 

 

💻 소스 코드 1 (합병 정렬 구현하기)


# 합병정렬
import sys
input = sys.stdin.readline

# 입력
n = int(input())
list = [int(input().rstrip()) for _ in range(n)]

# 합병 정렬
def mergeSort(list):
    if len(list) <= 1:
        return list
    
    # 1) Divide - 배열 반으로 나누기
    mid = len(list) // 2

    # 2) Conquer - 나뉜 배열들을 각각 정렬하기
    left = mergeSort(list[:mid])
    right = mergeSort(list[mid+1:])

    # 3) Combine - 다시 하나의 배열로 합치기
    merge(left, right)

def merge(left, right):
    sorted = []
    p1, p2 = 0, 0 # 나뉜 배열 2개를 각각 가리키는 인덱스

    while len(left) > p1 and len(right) > p2:
        if left[p1] > right[p2]:
            sorted.append(right[p2])
            p2 += 1
        else:
            sorted.append(left[p1])
            p1 += 1
        
    while len(left) > p1 and len(right) <= p2:
        sorted.append(left[p1])
        p1 += 1
    
    while len(right) > p2 and len(left) <= p1:
        sorted.append(right[p2])
        p2 += 1

mergeSort(list)

# 출력
for i in sorted:
    print(i)

 

💻 소스 코드 2 (sort() 함수 사용하기)


import sys
input = sys.stdin.readline

# 입력
n = int(input())
li = [int(input()) for _ in range(n)]

# 정렬
li.sort()

# 출력
print(*li, sep='\n')