matsulibの日記

自分用メモ

Pythonで関数プログラミング 珠玉のアルゴリズムデザイン 第19章「単純な数独ソルバー」

最近読み始めた本。

www.amazon.co.jp

どの章からでも読めるらしいのでとりあえず理解できそうな数独から。
紹介されているアルゴリズムはおそらくタイトルの通り単純なんだろうが、Haskellのサンプルコードで思考停止した…
そこで両言語の勉強がてらHaskellのサンプルコードをPythonに移植してみた。

Pythonと比較してHaskellについて思ったこと:

  • 型システムすごい
  • パターンマッチすごい
  • Python「lambda x: x != 1」 Haskell「(/= 1)」 中置関数とかカリー化とか部分適用とかすごい
  • 関数合成で括弧がかなり減らせて格好いい

格好いいけどムズい。使いこなせてる人は頭が良いんだろうな。

以下、書いたPythonコード。なんか違和感があるけど気にしない。

from itertools import chain
concat = lambda xs: list(chain.from_iterable(xs))

digits = list(range(1, 10))
blank = lambda x: x == 0

def choices(grid):
    choice = lambda d: digits if blank(d) else [d]

    return [[choice(d) for d in row] for row in grid]

def nodups(xs):
    if xs == []:
        return True
    else:
        y, *ys = xs
        return all(map(lambda x: x != y, ys)) and nodups(ys)

def rows(grid):
    return grid

def cols(grid):
    return zip(*grid)

def boxs(grid):
    return list(map(ungroup, ungroup(list(map(cols, group(list(map(group, grid))))))))

def group(xs):
    xs = list(xs)

    if xs==[]:
        return []
    else:
        return [xs[:3]] + group(xs[3:])

def ungroup(xss):
    return concat(xss)

def prune(matrix_choices):
    return pruneBy(boxs, (pruneBy(cols, pruneBy(rows, matrix_choices))))

def pruneBy(f, matrix):
    return f(map(pruneRow, f(matrix)))

def pruneRow(row_choices):
    fixed = [d[0] for d in row_choices if len(d)==1]
    return map(lambda ds: remove(fixed, ds), row_choices)

def remove(xs, ds):
    if single(ds):
        return ds
    else:
        return [d for d in ds if d not in xs]

def single(ds):
    return len(ds) == 1

def expand1(rows):
    n = min(counts(rows))
    smallest = lambda cs: len(cs) == n
    rows1, (row, *rows2) = _break(lambda xs: any(map(smallest, xs)), rows)
    row1, (cs, *row2) = _break(smallest, row)

    return [rows1 + [row1 + [[c]] + row2]  + rows2 for c in cs]

def _break(predicate, xs):
    if len(xs) == 0:
        return [], xs

    for i in range(len(xs)):
        if predicate(xs[i]):
            return xs[:i], xs[i:]

def counts(xss):
    return filter(lambda x: x != 1, (map(len, concat(xss))))

def complete(matrix):
    return all([all(map(single, m)) for m in matrix])

def safe(matrix):
    return all(map(ok, rows(matrix))) \
            and all(map(ok, cols(matrix))) \
            and all(map(ok, boxs(matrix)))
            
def ok(row):
    return nodups([d[0] for d in row if len(d)==1])

def search(matrix):
    ms = prune(matrix)

    if not safe(matrix):
        return []
    elif complete(ms):
        return [[[r[0] for r in row] for row in ms]]
    else:
        return concat(map(search, expand1(ms)))

def solve(grid):
    return search(choices(grid))


def test1():
    import time
    s = time.time()

    m = [[0, 5, 0, 0, 6, 0, 0, 0, 1],
        [0, 0, 4, 8, 0, 0, 0, 7, 0],
        [8, 0, 0, 0, 0, 0, 0, 5, 2],
        [2, 0, 0, 0, 5, 7, 0, 3, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 6, 9, 0, 0, 0, 5],
        [7, 9, 0, 0, 0, 0, 0, 0, 8],
        [0, 1, 0, 0, 0, 6, 5, 0, 0],
        [5, 0, 0, 0, 3, 0, 0, 6, 0]]

    print(solve(m))
    print("%.3fs" % (time.time() - s))

def test2():
    import time
    s = time.time()

    m = [[0, 0, 0, 0, 6, 0, 0, 8, 0],
        [0, 2, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 7, 0, 0, 0, 0, 1, 0, 2],
        [5, 0, 0, 0, 3, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 4, 0, 0],
        [0, 0, 4, 2, 0, 1, 0, 0, 0],
        [3, 0, 0, 7, 0, 0, 6, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 5, 0]]

    print(len(solve(m)))
    print("%.3fs" % (time.time() - s))


if __name__ == '__main__':
    test1()
    # [[[9, 5, 3, 7, 6, 2, 8, 4, 1], [6, 2, 4, 8, 1, 5, 9, 7, 3], [8, 7, 1, 3, 4, 9, 6, 5, 2], [2, 8, 9, 4, 5, 7, 1, 3, 6], [1, 6, 5, 2, 8, 3, 4, 9, 7], [4, 3, 7, 6, 9, 1, 2, 8, 5], [7, 9, 6, 5, 2, 4, 3, 1, 8], [3, 1, 8, 9, 7, 6, 5, 2, 4], [5, 4, 2, 1, 3, 8, 7, 6, 9]]]
    # 0.433s

    test2()
    # 33
    # 666.855s

2つの問題は https://wiki.haskell.org/Sudoku#Test_boards から。

2つ目の問題は解が33個あって全部探すのに時間がかかった。ちなみに同じ環境でPyPyだと300秒、Haskellだと移植元のコードで95秒だった。