Pythonで関数プログラミング 珠玉のアルゴリズムデザイン 第19章「単純な数独ソルバー」
- 型システムすごい
- パターンマッチすごい
- 関数合成で括弧がかなり減らせて格好いい
- printデバッグの方法がよく分からない
import time
from itertools import chain

# Simple Sudoku solver written in a functional style.
#
# Data shapes used throughout:
#   * grid   -- 9x9 list of rows of ints, where 0 marks a blank cell.
#   * matrix -- same shape, but every cell is the list of digits that are
#               still candidates for that cell ("matrix of choices").

digits = list(range(1, 10))


def concat(xs):
    """Flatten one level of nesting: iterable of iterables -> one list."""
    return list(chain.from_iterable(xs))


def blank(x):
    """Return True when the cell value ``x`` marks an empty cell (0)."""
    return x == 0


def choices(grid):
    """Turn a grid into a matrix of choices.

    A blank cell becomes a fresh copy of all digits, a given cell ``d``
    becomes ``[d]``.
    """
    # Copy ``digits`` per cell so no cell aliases the module-level list.
    return [[list(digits) if blank(d) else [d] for d in row] for row in grid]


def nodups(xs):
    """Return True when no two elements of ``xs`` compare equal."""
    seen = []
    for x in xs:
        if x in seen:
            return False
        seen.append(x)
    return True


def rows(grid):
    """Return the rows of ``grid`` (the identity view)."""
    return grid


def cols(grid):
    """Return the columns of ``grid`` as a list of lists (transpose)."""
    return [list(col) for col in zip(*grid)]


def boxs(grid):
    """Return the 3x3 boxes of ``grid``, each flattened to one list.

    Applying ``boxs`` twice to a 9x9 grid gives back the original rows.
    """
    # Split every row into groups of 3 cells, then the rows into bands of 3.
    bands = group([group(row) for row in grid])
    # Transposing a band lines up the groups that share a box.
    stacks = ungroup([cols(band) for band in bands])
    return [ungroup(stack) for stack in stacks]


def group(xs):
    """Split ``xs`` into consecutive chunks of (at most) three elements."""
    xs = list(xs)
    return [xs[i:i + 3] for i in range(0, len(xs), 3)]


def ungroup(xss):
    """Inverse of ``group``: concatenate the chunks back into one list."""
    return concat(xss)


def prune(matrix_choices):
    """Remove candidates ruled out by fixed cells in rows, columns, boxes."""
    return pruneBy(boxs, pruneBy(cols, pruneBy(rows, matrix_choices)))


def pruneBy(f, matrix):
    """Prune ``matrix`` along the view ``f`` (``rows``, ``cols`` or ``boxs``).

    ``f`` is applied twice: once to obtain the view and once to map the
    pruned view back.
    """
    return f([pruneRow(row) for row in f(matrix)])


def pruneRow(row_choices):
    """Drop from every open cell the digits already fixed in the same row.

    Returns a list with one candidate list per cell.
    """
    # Materialise first: the input is traversed twice below.
    row_choices = list(row_choices)
    fixed = [ds[0] for ds in row_choices if single(ds)]
    return [remove(fixed, ds) for ds in row_choices]


def remove(xs, ds):
    """Return ``ds`` without the digits in ``xs``; fixed cells are kept."""
    if single(ds):
        return ds
    return [d for d in ds if d not in xs]


def single(ds):
    """Return True when ``ds`` holds exactly one candidate."""
    return len(ds) == 1


def expand1(rows):
    """Expand the first cell that has the fewest candidates (other than 1).

    Returns one matrix per candidate of that cell, with the cell fixed to
    that candidate. An empty list is returned when the chosen cell has no
    candidates at all.

    Raises:
        ValueError: if every cell already holds exactly one candidate.
    """
    rows = [list(row) for row in rows]
    n = min(counts(rows))

    def smallest(cs):
        return len(cs) == n

    rows1, (row, *rows2) = _break(lambda xs: any(map(smallest, xs)), rows)
    row1, (cs, *row2) = _break(smallest, row)
    return [rows1 + [row1 + [[c]] + row2] + rows2 for c in cs]


def _break(predicate, xs):
    """Split ``xs`` just before the first element satisfying ``predicate``.

    When nothing matches, the whole sequence is returned as the first part
    and an empty sequence as the second.
    """
    for i, x in enumerate(xs):
        if predicate(x):
            return xs[:i], xs[i:]
    return xs[:], xs[:0]


def counts(xss):
    """Return the candidate counts of all cells that are not yet fixed."""
    return [len(cs) for cs in concat(xss) if len(cs) != 1]


def complete(matrix):
    """Return True when every cell holds exactly one candidate."""
    return all(single(cs) for row in matrix for cs in row)


def safe(matrix):
    """Return True when the fixed cells contain no duplicate digit in any
    row, column or box."""
    return (all(map(ok, rows(matrix)))
            and all(map(ok, cols(matrix)))
            and all(map(ok, boxs(matrix))))


def ok(row):
    """Return True when the fixed cells of ``row`` are pairwise distinct."""
    return nodups([ds[0] for ds in row if single(ds)])


def search(matrix):
    """Return every solved grid reachable from the matrix of choices."""
    ms = prune(matrix)
    # Check the *pruned* matrix: pruning can fix a cell to a digit that
    # collides with an existing fixed cell, and such a matrix must not be
    # accepted as a solution. Fixed cells survive pruning unchanged, so
    # this also covers the safety of the incoming matrix.
    if not safe(ms):
        return []
    if complete(ms):
        return [[[cs[0] for cs in row] for row in ms]]
    return concat(map(search, expand1(ms)))


def solve(grid):
    """Return the list of all solutions of ``grid`` (0 marks a blank)."""
    return search(choices(grid))


def test1():
    """Solve the first sample puzzle; print the solutions and elapsed time."""
    s = time.time()
    m = [[0, 5, 0, 0, 6, 0, 0, 0, 1],
         [0, 0, 4, 8, 0, 0, 0, 7, 0],
         [8, 0, 0, 0, 0, 0, 0, 5, 2],
         [2, 0, 0, 0, 5, 7, 0, 3, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 0, 6, 9, 0, 0, 0, 5],
         [7, 9, 0, 0, 0, 0, 0, 0, 8],
         [0, 1, 0, 0, 0, 6, 5, 0, 0],
         [5, 0, 0, 0, 3, 0, 0, 6, 0]]
    print(solve(m))
    print("%.3fs" % (time.time() - s))


def test2():
    """Solve the second sample puzzle; print the solution count and time."""
    s = time.time()
    m = [[0, 0, 0, 0, 6, 0, 0, 8, 0],
         [0, 2, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 7, 0, 0, 0, 0, 1, 0, 2],
         [5, 0, 0, 0, 3, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 4, 0, 0],
         [0, 0, 4, 2, 0, 1, 0, 0, 0],
         [3, 0, 0, 7, 0, 0, 6, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 5, 0]]
    print(len(solve(m)))
    print("%.3fs" % (time.time() - s))


if __name__ == '__main__':
    # The commented values are the outputs recorded in the original source,
    # before the safety check in search() was corrected; counts and timings
    # may no longer match.
    test1()
    # [[[9, 5, 3, 7, 6, 2, 8, 4, 1], [6, 2, 4, 8, 1, 5, 9, 7, 3], [8, 7, 1, 3, 4, 9, 6, 5, 2], [2, 8, 9, 4, 5, 7, 1, 3, 6], [1, 6, 5, 2, 8, 3, 4, 9, 7], [4, 3, 7, 6, 9, 1, 2, 8, 5], [7, 9, 6, 5, 2, 4, 3, 1, 8], [3, 1, 8, 9, 7, 6, 5, 2, 4], [5, 4, 2, 1, 3, 8, 7, 6, 9]]]
    # 0.433s
    test2()
    # 33
    # 666.855s
2つの問題の出典は外部サイト（リンクは欠落）。