Pythonで関数プログラミング 珠玉のアルゴリズムデザイン 第19章「単純な数独ソルバー」
最近読み始めた本。
どの章からでも読めるらしいのでとりあえず理解できそうな数独から。
紹介されているアルゴリズムはおそらくタイトルの通り単純なんだろうが、Haskellのサンプルコードで思考停止した…
そこで両言語の勉強がてらHaskellのサンプルコードをPythonに移植してみた。
- 型システムすごい
- パターンマッチすごい
- 関数合成で括弧がかなり減らせて格好いい
- printデバッグの方法がよく分からない
格好いいけどムズい。使いこなせてる人は頭が良いんだろうな。
以下、書いたPythonコード。なんか違和感があるけど気にしない。
from itertools import chain concat = lambda xs: list(chain.from_iterable(xs)) digits = list(range(1, 10)) blank = lambda x: x == 0 def choices(grid): choice = lambda d: digits if blank(d) else [d] return [[choice(d) for d in row] for row in grid] def nodups(xs): if xs == []: return True else: y, *ys = xs return all(map(lambda x: x != y, ys)) and nodups(ys) def rows(grid): return grid def cols(grid): return zip(*grid) def boxs(grid): return list(map(ungroup, ungroup(list(map(cols, group(list(map(group, grid)))))))) def group(xs): xs = list(xs) if xs==[]: return [] else: return [xs[:3]] + group(xs[3:]) def ungroup(xss): return concat(xss) def prune(matrix_choices): return pruneBy(boxs, (pruneBy(cols, pruneBy(rows, matrix_choices)))) def pruneBy(f, matrix): return f(map(pruneRow, f(matrix))) def pruneRow(row_choices): fixed = [d[0] for d in row_choices if len(d)==1] return map(lambda ds: remove(fixed, ds), row_choices) def remove(xs, ds): if single(ds): return ds else: return [d for d in ds if d not in xs] def single(ds): return len(ds) == 1 def expand1(rows): n = min(counts(rows)) smallest = lambda cs: len(cs) == n rows1, (row, *rows2) = _break(lambda xs: any(map(smallest, xs)), rows) row1, (cs, *row2) = _break(smallest, row) return [rows1 + [row1 + [[c]] + row2] + rows2 for c in cs] def _break(predicate, xs): if len(xs) == 0: return [], xs for i in range(len(xs)): if predicate(xs[i]): return xs[:i], xs[i:] def counts(xss): return filter(lambda x: x != 1, (map(len, concat(xss)))) def complete(matrix): return all([all(map(single, m)) for m in matrix]) def safe(matrix): return all(map(ok, rows(matrix))) \ and all(map(ok, cols(matrix))) \ and all(map(ok, boxs(matrix))) def ok(row): return nodups([d[0] for d in row if len(d)==1]) def search(matrix): ms = prune(matrix) if not safe(matrix): return [] elif complete(ms): return [[[r[0] for r in row] for row in ms]] else: return concat(map(search, expand1(ms))) def solve(grid): return search(choices(grid)) def test1(): import time s = time.time() m = [[0, 5, 0, 0, 6, 0, 0, 0, 1], [0, 0, 4, 8, 0, 0, 0, 7, 0], [8, 0, 0, 0, 0, 0, 0, 5, 2], [2, 0, 0, 0, 5, 7, 0, 3, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 3, 0, 6, 9, 0, 0, 0, 5], [7, 9, 0, 0, 0, 0, 0, 0, 8], [0, 1, 0, 0, 0, 6, 5, 0, 0], [5, 0, 0, 0, 3, 0, 0, 6, 0]] print(solve(m)) print("%.3fs" % (time.time() - s)) def test2(): import time s = time.time() m = [[0, 0, 0, 0, 6, 0, 0, 8, 0], [0, 2, 0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 0], [0, 7, 0, 0, 0, 0, 1, 0, 2], [5, 0, 0, 0, 3, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 4, 0, 0], [0, 0, 4, 2, 0, 1, 0, 0, 0], [3, 0, 0, 7, 0, 0, 6, 0, 0], [0, 0, 0, 0, 0, 0, 0, 5, 0]] print(len(solve(m))) print("%.3fs" % (time.time() - s)) if __name__ == '__main__': test1() # [[[9, 5, 3, 7, 6, 2, 8, 4, 1], [6, 2, 4, 8, 1, 5, 9, 7, 3], [8, 7, 1, 3, 4, 9, 6, 5, 2], [2, 8, 9, 4, 5, 7, 1, 3, 6], [1, 6, 5, 2, 8, 3, 4, 9, 7], [4, 3, 7, 6, 9, 1, 2, 8, 5], [7, 9, 6, 5, 2, 4, 3, 1, 8], [3, 1, 8, 9, 7, 6, 5, 2, 4], [5, 4, 2, 1, 3, 8, 7, 6, 9]]] # 0.433s test2() # 33 # 666.855s
2つの問題は https://wiki.haskell.org/Sudoku#Test_boards から。
2つ目の問題は解が33個あって全部探すのに時間がかかった。ちなみに同じ環境でPyPyだと300秒、Haskellだと移植元のコードで95秒だった。