Pythonでリストからn番目に大きな要素を取り出す
最近はずっとプログラミングできてなかったけど、ふと前に授業で使ってたノートを見たら「縮小法」なるアルゴリズムが載っていた。これはリストからn番目に大きな要素を取り出すアルゴリズムで、英語版wikipediaではmedian of mediansという項目に詳しく書かれている。
Median of medians - Wikipedia, the free encyclopedia
結論から言えばO(n)で計算できるからソートを使った方法O(nlogn)より計算量が少ないという。
そう言えばそういうのあったなあ、という感覚だけどせっかくだからリハビリがてら実装してみることにした。
同じことをやるアルゴリズムは英語版wikipediaの
Selection algorithm - Wikipedia, the free encyclopedia
に色々と書かれているけど、最後の方にPythonは標準ライブラリにheapq.nlargest()という関数があって、それを使うとO(n log k) で計算できるらしいのでついでに試してみた。
参考資料
[PPT]中京大、白井英俊先生
http://www.cyber.sist.chukyo-u.ac.jp/classes/algo/PPT/Chap9.ppt
ソースコード
from pylab import * import time import random import heapq def select_sorted(L, k): return sorted(L, reverse=True)[k-1] def select_heapq_nlargest(L, k): return heapq.nlargest(k, L)[-1] def select_median_of_medians(L, k): nL = len(L) if nL < 50: return sorted(L, reverse=True)[k-1] R = [sorted(L[i*5:(i+1)*5])[2] for i in range(int(nL/5))] m = select_median_of_medians(R, int(nL/10)) L1 = [v for v in L if v > m] L2 = [v for v in L if v == m] L3 = [v for v in L if v < m] if k <= len(L1): return select_median_of_medians(L1, k) elif k <= len(L1)+len(L2): return m else: return select_median_of_medians(L3, k-len(L1)-len(L2)) if __name__ == '__main__': N, M = 1, 17 my_range = range(N, M) ndatas = [2**i for i in my_range] ndatas_to_show = ['2^{}'.format(i) for i in my_range] times_sorted = [0] * len(ndatas) times_heapq_nlargest = [0] * len(ndatas) times_median_of_medians = [0] * len(ndatas) all_samples = random.sample(range(20**6), max(ndatas)) for i, n in enumerate(ndatas): samples = all_samples[:n] k = int(n/2) # sorted() start = time.time() kth_number = select_sorted(samples, k) times_sorted[i] = time.time() - start # heapq.nlargest() start = time.time() kth_number = select_heapq_nlargest(samples, k) times_heapq_nlargest[i] = time.time() - start # median_of_median() start = time.time() kth_number = select_median_of_medians(samples, k) times_median_of_medians[i] = time.time() - start plot(ndatas, times_sorted, 'o-') plot(ndatas, times_heapq_nlargest, 'x-') plot(ndatas, times_median_of_medians, '^-') xticks(ndatas, ndatas_to_show) legend(('sorted()', 'heapq.nlargest()', 'median_of_medians()'), 'best') title('Selection Algorithm') xlabel('number of data') ylabel('time [s]') grid() show()
結果
ちくしょう、ソートくそ速えな…(´・ω・`)
でもナイーブな実装の割には善戦したとも言えるかも。
heapq.nlargest()については公式ドキュメントでsorted()の方が速いと書いてあった。
8.4. heapq — ヒープキューアルゴリズム — Python 2.7ja1 documentation