matsulibの日記

Ingredients as Code

Pythonでリストからn番目に大きな要素を取り出す

最近はずっとプログラミングできてなかったけど、ふと前に授業で使ってたノートを見たら「縮小法」なるアルゴリズムが載っていた。これはリストからn番目に大きな要素を取り出すアルゴリズムで、英語版wikipediaではmedian of mediansという項目に詳しく書かれている。
Median of medians - Wikipedia, the free encyclopedia
結論から言えばO(n)で計算できるからソートを使った方法O(nlogn)より計算量が少ないという。
そう言えばそういうのあったなあ、という感覚だけどせっかくだからリハビリがてら実装してみることにした。
同じことをやるアルゴリズム英語版wikipedia
Selection algorithm - Wikipedia, the free encyclopedia
に色々と書かれているけど、最後の方にPythonは標準ライブラリにheapq.nlargest()という関数があって、それを使うとO(n log k) で計算できるらしいのでついでに試してみた。

ソースコード

from pylab import *
import time
import random
import heapq


def select_sorted(L, k):
    return sorted(L, reverse=True)[k-1]


def select_heapq_nlargest(L, k):
    return heapq.nlargest(k, L)[-1]


def select_median_of_medians(L, k):
    nL = len(L)
    if nL < 50:
        return sorted(L, reverse=True)[k-1]

    R = [sorted(L[i*5:(i+1)*5])[2] for i in range(int(nL/5))]
    m = select_median_of_medians(R, int(nL/10))

    L1 = [v for v in L if v > m]
    L2 = [v for v in L if v == m]
    L3 = [v for v in L if v < m]

    if k <= len(L1):
        return select_median_of_medians(L1, k)
    elif k <= len(L1)+len(L2):
        return m
    else:
        return select_median_of_medians(L3, k-len(L1)-len(L2))


if __name__ == '__main__':
    N, M = 1, 17
    my_range = range(N, M)
    ndatas = [2**i for i in my_range]
    ndatas_to_show = ['2^{}'.format(i) for i in my_range]
    times_sorted = [0] * len(ndatas)
    times_heapq_nlargest = [0] * len(ndatas)
    times_median_of_medians = [0] * len(ndatas)

    all_samples = random.sample(range(20**6), max(ndatas))

    for i, n in enumerate(ndatas):
        samples = all_samples[:n]
        k = int(n/2)

        # sorted()
        start = time.time()
        kth_number = select_sorted(samples, k)
        times_sorted[i] = time.time() - start

        # heapq.nlargest()
        start = time.time()
        kth_number = select_heapq_nlargest(samples, k)
        times_heapq_nlargest[i] = time.time() - start

        # median_of_median()
        start = time.time()
        kth_number = select_median_of_medians(samples, k)
        times_median_of_medians[i] = time.time() - start


    plot(ndatas, times_sorted, 'o-')
    plot(ndatas, times_heapq_nlargest, 'x-')
    plot(ndatas, times_median_of_medians, '^-')
    xticks(ndatas, ndatas_to_show)

    legend(('sorted()', 'heapq.nlargest()', 'median_of_medians()'), 'best')
    title('Selection Algorithm')
    xlabel('number of data')
    ylabel('time [s]')
    grid()
    show()

結果

f:id:matsulib:20141202023000p:plain:w500
f:id:matsulib:20141202021807p:plain:w500
f:id:matsulib:20141202021827p:plain:w500

ちくしょう、ソートくそ速えな…(´・ω・`)
でもナイーブな実装の割には善戦したとも言えるかも。
heapq.nlargest()については公式ドキュメントでsorted()の方が速いと書いてあった。
8.4. heapq — ヒープキューアルゴリズム — Python 2.7ja1 documentation