基本的なソートアルゴリズムまとめ

前に買って積ん読行きとなっていた「プログラミングコンテスト攻略のためのアルゴリズムとデータ構造」、通称「螺旋本」を少しずつ読み進めていきたいと思います。

今回は第3章のソートアルゴリズムについてです。

ソートとは簡単に言ってしまえば並び替えのことです。 配列を値の大きい順やアルファベット順に並び替えるアルゴリズムです。

挿入ソート

要素を初めから順に取り出して、すでに並び替え済みのものの中の適切な位置に入れていくアルゴリズムです。 実装する上では(昇順ソートだとして)、取り出して空いた隙間に、取り出した値より前にあってかつ大きい値をとるものを、後ろに一つずつずらしていくという方法をとります。

AIZU ONLINE JUDGEの問題を解いて実装とします。(以降の実装も同様です)

LDS1_1_A.py
N = int(input())
A = list(map(int,input().split()))
print(" ".join([str(a) for a in A]))

def insertionSort(N,A):
    for i in range(1,N):
        v = A[i]
        j = i-1
        while A[j] > v and j >= 0:
            A[j+1] = A[j]
            j -= 1
        A[j+1] = v
        return A

print(" ".join([str(a) for a in insertionSort(N,A)]))

計算量は元のリストが降順に並んでいた場合に最悪で O(N2)O(N^2) となります。逆に、すでに昇順で並んでいた場合は O(N)O(N) です。 元のリストがゴールに近い状態であるほど計算量が小さいアルゴリズムです。

バブルソート

泡が水面に上がっていくように、要素を最小値から順に前に移動させていくアルゴリズムです。 実装する上では、後ろから順に隣り合う二つの要素を比べて大小関係が逆なら交換を行う操作を繰り返します。交換が一度も行われなくなったらソート完了です。

ALDS1_2_A.py
N = int(input())
A = list(map(int,input().split()))

def bubbleSort(N,A):
    cnt = 0
    flag = 1
    while flag == 1:
        flag = 0
        for j in range(N-1,0,-1):
            if A[j] < A[j-1]:
                tmp = A[j]
                A[j] = A[j-1]
                A[j-1] = tmp
                flag = 1
                cnt += 1
    return A,cnt

A,cnt = bubbleSort(N,A)
print(" ".join([str(s) for s in A]))
print(cnt)

前から順にソート済み配列を作っていくイメージなので、whileの中身がi回終わった時点でリストの前からi個の要素はすでに確定しています。 そこで、iをwhileの中身を繰り返した回数として、8行目のrange(N-1,0,-1)の部分をrange(N-1,i,-1)としてやると少し計算量が減ります。 いずれにせよ O(N2)O(N^2) のアルゴリズムです。 バブルソートの要素の交換回数は転倒数などと呼ばれ、リストの乱れ具合の指標とされるそうです。

選択ソート

ii番目以降のリストの中の最小値をii番目に移動させていくアルゴリズムです。 その時点で一番小さいものを選んで前に持ってくる、という直感的な操作となっています。

ALDS1_2_B.py
N = int(input())
A = list(map(int,input().split()))

def selectionSort(N,A):
    cnt = 0
    for i in range(N):
        minj = i
        for j in range(i,N):
            if A[j] < A[minj]:
                minj = j
        
        if minj != i:
            tmp = A[minj]
            A[minj] = A[i]
            A[i] = tmp
            cnt += 1
    return A,cnt

A,cnt = selectionSort(N,A)
print(" ".join([str(s) for s in A]))
print(cnt)

同じ値の要素の並びがソート前後で変化しない方法を安定なソートと言います。 挿入ソートやバブルソートが安定なソートなのに対して、選択ソートは安定なソートではありません。

シェルソート

「元のリストから一定間隔 gg で値を抜き出した全てのパターンそれぞれに対して挿入ソートを行う」という操作を、gg を小さくしながら繰り返すアルゴリズムです。

最終的に g=1g=1、つまり通常の挿入ソートを行います。既に整列がある程度済んでいるリストに対して、挿入ソートが O(N)O(N) に近い計算量で行えることを利用した高速な方法です。

実装は(汚いですが)例えば次のようになります。

ALDS1_2_D.py
N = int(input())
A = []
for i in range(N):
    A.append(int(input()))

def insertionSort(N,A,g):
    global cnt
    for i in range(g,N):
        v = A[i]
        j = i-g
        while A[j] > v and j >= 0:
            A[j+g] = A[j]
            j -= g
            cnt += 1
        A[j+g] = v
    return A    
      
def shellSort(N,A):
    global cnt
    cnt = 0
    m = 1
    G = [1]
    x = 1
    for i in range(N):
        x = 3*x + 1
        if x > N:
            break
        G.insert(0,x)
        m += 1
    
    for i in range(m):
        A = insertionSort(N,A,G[i])
    
    return cnt,A,m,G

cnt,A,m,G = shellSort(N,A)
print(m)
print(" ".join([str(g) for g in G]))
print(cnt)
for i in range(len(A)):
    print(A[i])

ggの選び方としては色々あるそうですが、

gn+1=3gn+1g0=1g_{n+1} = 3g_n+1\quad g_0 = 1

の漸化式で表される数列を ggNN を超えない範囲で打ち切ったものを採用すると計算量が O(N1.25)O(N^{1.25}) 程度で済むそうです。

他にも gg の減少列は考えられますが、例えば g=8,4,2,1g=8,4,2,1 とすると効率が悪くなりそうなことは感覚的にわかると思います。( g=1g=1になるまで同じような組み合わせでのソートばかり行われて無駄が多いためです)

計算速度比較

シェルソートが本当に高速なのか、通常の挿入ソートやバブルソートと比べてみます。 次のコードを用意します。

import random
import time

N = int(1e4)
A = [random.randint(0,N) for i in range(N)]


def insertionSort(N,A,g):
    global cnt
    for i in range(g,N):
        v = A[i]
        j = i-g
        while A[j] > v and j >= 0:
            A[j+g] = A[j]
            j -= g
        A[j+g] = v

    return A    
      
def shellSort(N,A):
    m = 1
    G = [1]
    x = 1
    for i in range(N):
        x = 3*x + 1
        if x > N:
            break
        G.insert(0,x)
        m += 1
    
    for i in range(m):
        A = insertionSort(N,A,G[i])
    
    return A

def bubbleSort(N,A):
    flag = 1
    while flag == 1:
        flag = 0
        for j in range(N-1,0,-1):
            if A[j] < A[j-1]:
                tmp = A[j]
                A[j] = A[j-1]
                A[j-1] = tmp
                flag = 1
    return A

t1 = time.time()
sortedA1 = insertionSort(N,A.copy(),1)
t2 = time.time()
sortedA2 = bubbleSort(N,A.copy())
t3 = time.time()
sortedA3 = shellSort(N,A.copy())
t4 = time.time()

flag = "Yes"
for i in range(N):
    if sortedA1[i] == sortedA2[i] and sortedA2[i] == sortedA3[i]:
        pass
    else:
        flag = "No"


print("かかった時間\n挿入ソート:{:.2f}s\nバブルソート:{:.2f}s\
        \nシェルソート:{:.2f}s\n\nソート結果が一致したか:{}"\
        .format(t2-t1,t3-t2,t4-t3,flag))

結果は例えば 挿入ソート:3.98s バブルソート:14.57s シェルソート:0.04s のようになります。 N=104N=10^4程度で調べてみると、シェルソートが圧倒的に早いことがわかりました。

また、減少列を g=,16,8,4,2,1g=\cdots,16,8,4,2,1 のように変えて試してみると、g=,13,4,1g=\cdots,13,4,1 のときが0.04sだったのに対し、0.10sになりました。 遅くはなったものの他のソート方法と比べると圧倒的に速いですね。

今回のまとめ

  • 簡単なソートはだいたい O(N2)O(N^2)
  • 同じ値を持つ要素の順番が前後で変わらないようなソートを安定なソートという
  • シェルソートは適切な減少列を設定すると速い

参考
プログラミングコンテスト攻略のためのアルゴリズムとデータ構造」, 渡部有隆

BACK