matplotlibで線形分類器の分離平面を描くには


これをどうやって描いたかという話

まず矢印が変な形になるのを避けるために散布図の縦と横の縮尺を固定し(1)、重みベクトルで矢印を描き(2)、大きな灰色の長方形をほぼ透明にして重みベクトルの値を元に作ったアフィン変換で回転して描画している(3)

data1, data2, w, data2colorなんかはグローバルスコープから拾ってきているので気にしないように。

def draw(figureid):
    clf()

    ax = gca()
    ax.set_ylim(-6, 6) # (1)
    ax.set_xlim(-6, 6) # (1)

    scatter(data1[:, 0], data1[:, 1], edgecolors='blue', facecolors=data2color(data1))
    scatter(data2[:, 0], data2[:, 1], edgecolors='red', facecolors=data2color(data2))

    # (2)
    arr = YAArrow(fig, w, (0, 0), alpha=0.8, width=0.2, 
                  headwidth=0.6, frac=0.3, facecolor='red')
    ax.add_patch(arr)

    # (3)
    x, y = w
    t = CompositeGenericTransform(Affine2D.from_values(x, y, y, -x, 0, 0), ax.transData)
    rect = Rectangle((0, -100), 100, 200, transform=t, alpha=0.1, facecolor='grey')
    ax.add_patch(rect)

    ax.text(0.05, 0.95, str(figureid), transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top')

    matplotlib.pyplot.savefig("lr%04d.png" % figureid, dpi=50)