Needleman-Wunsch algorithmを実装した

前に86チャットで誰かがdiffを作ろうとしていて、長さ N, M の文字列に対して素朴に N * M の配列を確保していたせいで大きなサイズの入力でメモリを食い過ぎて破綻していたときに「動的計画法で端から埋めていくんだから、直前の1列だけ取っておくだけでいいでしょ」「いや、後でパスを求めないと行けないから全部持つ必要があるんだ!」「ないよ!パスの根元だけ参照で持っておけばいらないパスはGCで消えるでしょ!」という話をしたんだが、今ちょっと自分でも必要になったので作ってみた。うん、思った通りに動くな。

テストケースに書いてあるけど、実行すると下のような入出力になる。

    >>> test("1", "001")
    in1:  1
    in2:  001
    (((), (0, 2)), 'D')
    out1: --1
    out2: 001

    >>> test("1", "100")
    in1:  1
    in2:  100
    (((((), (0, 0)), 'D'), 'V'), 'V')
    out1: 1--
    out2: 100

    >>> test("1", "00100")
    in1:  1
    in2:  00100
    (((((), (0, 2)), 'D'), 'V'), 'V')
    out1: --1--
    out2: 00100

    >>> test("11", "1001")
    in1:  11
    in2:  1001
    ((((((), (0, 0)), 'D'), 'V'), 'V'), 'D')
    out1: 1--1
    out2: 1001

    >>> data1 = [1,2,3,5,7,9,0]
    >>> data2 = [2,3,4,5,6,7,8,0]
    >>> test(data1, data2)
    in1:  [1, 2, 3, 5, 7, 9]
    in2:  [2, 3, 4, 5, 6, 7, 8]
    (((((((((), (1, 0)), 'D'), 'D'), 'V'), 'D'), 'V'), 'D'), 'D')
    out1: 123-5-79
    out2: -2345678

僕がやりたいことは2つのよく似たHTMLのdiffなんだけど、普通の行単位のdiffを使ったらうまく行かない。文字単位のdiffで行けるかなと思ったけど、文中のLibraryのyがもう片方のbodyタグのyにマッチしたりしてとてもあほな結果になってしまう。せっかく「HTMLである」という入力の構造に関する知見があるのだからそれを利用しない手はない。というわけで正規表現で<〜〜>の形になっている部分となってない部分とでトークンにわけて、それでdiffしたら割とうまくいったんだけど、ちょっと期待に足りない。テキストノードの不一致は「違っていて当たり前」という入力なのであんまり重視したくなくて、ulタグとaタグをの違いは大きな違いと見なしてほしい。そこら編をカスタマイズしたかったのでスコアの計算関数をくくりだしてみた。今はまだ単純な比較になっている。

下のコードで文字列を食わせたり整数のリストを食わせたりしているのでわかるように、文字列かどうかは本質的ではない。ただ、今はランダムアクセスが必要になっているなー。ファイルの内容をメモリ上に持ってしまう。まぁ、僕が応用したい対象はギガ単位のサイズになったりしないのでいっかな。

"""
Needleman-Wunsch algorithm
"""

def calc_cost(x, y):
    """
    calc inverse similariry
    0: completely match, 1: completely different
    """
    if x == y:
        return 0
    return 1

GAP_COST = 1

def diff(s1, s2):
    W = len(s1) # widht
    H = len(s2) # height
    costs = [GAP_COST * y for y in range(H + 1)]
    paths = [((), (0, y)) for y in range(H + 1)]
    for x in range(1, W + 1):
        next_costs = [GAP_COST * x]
        next_paths = [((), (x, 0))] 
        for y in range(1, H + 1):
            # horizontal gap
            min_cost = costs[y] + GAP_COST
            min_link = (paths[y], "H")
            # vertical gap
            cost = next_costs[y - 1] + GAP_COST
            if cost < min_cost:
                min_cost = cost
                min_link = (next_paths[y - 1], "V")
            # diagonal
            cost = costs[y - 1] + calc_cost(s1[x - 1], s2[y - 1])
            if cost < min_cost:
                min_cost = cost
                min_link = (paths[y - 1], "D")

            next_costs.append(min_cost)
            next_paths.append(min_link)
        costs = next_costs
        paths = next_paths
    return paths[-1]

def test(data1, data2):
    """
    >>> test("1", "001")
    in1:  1
    in2:  001
    (((), (0, 2)), 'D')
    out1: --1
    out2: 001

    >>> test("1", "100")
    in1:  1
    in2:  100
    (((((), (0, 0)), 'D'), 'V'), 'V')
    out1: 1--
    out2: 100

    >>> test("1", "00100")
    in1:  1
    in2:  00100
    (((((), (0, 2)), 'D'), 'V'), 'V')
    out1: --1--
    out2: 00100

    >>> test("11", "1001")
    in1:  11
    in2:  1001
    ((((((), (0, 0)), 'D'), 'V'), 'V'), 'D')
    out1: 1--1
    out2: 1001

    >>> data1 = [1,2,3,5,7,9,0]
    >>> data2 = [2,3,4,5,6,7,8,0]
    >>> test(data1, data2)
    in1:  [1, 2, 3, 5, 7, 9]
    in2:  [2, 3, 4, 5, 6, 7, 8]
    (((((((((), (1, 0)), 'D'), 'D'), 'V'), 'D'), 'V'), 'D'), 'D')
    out1: 123-5-79
    out2: -2345678
    """
    print "in1: ", data1
    print "in2: ", data2 
    result = diff(data1, data2)
    print result
    buf1 = []
    buf2 = []
    x = len(data1) - 1
    y = len(data2) - 1
    while result:
        cur = result[1]
        next = result[0]
        if cur == "D":
            buf1.insert(0, str(data1[x]))
            x -= 1
            buf2.insert(0, str(data2[y]))
            y -= 1
        elif cur == "H":
            buf1.insert(0, str(data1[x]))
            x -= 1
            buf2.insert(0, "-")
        elif cur == "V":
            buf1.insert(0, "-")
            buf2.insert(0, str(data2[y]))
            y -= 1
        else:
            ox, oy = cur
            for i in range(ox):
                buf1.insert(0, str(data1[x]))
                x -= 1
                buf2.insert(0, "-")
                y -= 1
            for i in range(oy):
                buf1.insert(0, "-")
                x -= 1
                buf2.insert(0, str(data2[y]))
                y -= 1
                
        result = next

    print "out1:", "".join(buf1)
    print "out2:", "".join(buf2)

def _test():
    import doctest
    doctest.testmod()

if __name__ == "__main__":
    _test()