matplotlibでtransformを指定したpatchをadd_patchすると親axesの座標系を無視する

matplotlibでpatchにtransformをつけると消えてしまうように見える現象が起きた。
まず、散布図の上に普通に長方形のpatchを乗せてみる。

clf()
scatter(random(100), random(100))
ax = gca()
rect = Rectangle((0, 0), 1, 1, alpha=0.3)
ax.add_patch(rect)
matplotlib.pyplot.savefig("test.png", dpi=100)

ここまではとても自然にできる。じゃあこの長方形を回転してみよう、とtransformを指定すると消えてしまう。transformの作り方がまずいのかと思ったがAffine2D.identity()でも消えてしまう。

clf()
scatter(random(100), random(100))
ax = gca()
t = Affine2D.identity()
rect = Rectangle((0, 0), 1, 1, transform=t, alpha=0.3)
ax.add_patch(rect)
matplotlib.pyplot.savefig("test.png", dpi=100)

問題解決の鍵は前回(matplotlibのPatchCollectionは子パッチの色を上書きする)同様、patch自体ではなくその上にあった。

Definition: ax.add_patch(self, p)
Docstring:
Add a :class:`~matplotlib.patches.Patch` *p* to the list of
axes patches; the clipbox will be set to the Axes clipping
box.  If the transform is not set, it will be set to
:attr:`transData`.

つまりpatchにtransformを指定してadd_patchすると親axesの座標系を無視してしまう。っていうわけで今回の例で言えばtransformを指定しないときには散布図の軸に従った座標系で0〜1の範囲にあったものが、単位行列をtransformに指定したことで図のグローバル座標系になってしまっているわけだ。そして散布図のクリッピングの外なので消えてしまっている。長方形のサイズを100x100にしてみるとちょこっと頭を出す:「やあ、僕はここだよ!」

rect = Rectangle((0, 0), 100, 100, transform=t, alpha=0.3)

じゃあax.transDataをtransformに指定すればいいわけね。

t = ax.transData
rect = Rectangle((0, 0), 1, 1, transform=t, alpha=0.3)

コレは期待通りの結果だったので省略。じゃあ早速回転してみよう。

for i in range(5):
    t = CompositeGenericTransform(Affine2D.identity().rotate(0.1 * i), ax.transData)
    rect = Rectangle((0, 0), 1, 1, transform=t, alpha=0.3)
    ax.add_patch(rect)

できたできた。めでたしめでたし。でも、本当にCompositeGenericTransformなんか使わないといけないのだろうか。もっと楽な方法があるんじゃないかなぁ。

追記。とりあえずidentityは使わなくてもAffine2D()で単位行列になるみたい。

matplotlibで線形分類器の分離平面を描くには


これをどうやって描いたかという話

まず矢印が変な形になるのを避けるために散布図の縦と横の縮尺を固定し(1)、重みベクトルで矢印を描き(2)、大きな灰色の長方形をほぼ透明にして重みベクトルの値を元に作ったアフィン変換で回転して描画している(3)

data1, data2, w, data2colorなんかはグローバルスコープから拾ってきているので気にしないように。

def draw(figureid):
    clf()

    ax = gca()
    ax.set_ylim(-6, 6) # (1)
    ax.set_xlim(-6, 6) # (1)

    scatter(data1[:, 0], data1[:, 1], edgecolors='blue', facecolors=data2color(data1))
    scatter(data2[:, 0], data2[:, 1], edgecolors='red', facecolors=data2color(data2))

    # (2)
    arr = YAArrow(fig, w, (0, 0), alpha=0.8, width=0.2, 
                  headwidth=0.6, frac=0.3, facecolor='red')
    ax.add_patch(arr)

    # (3)
    x, y = w
    t = CompositeGenericTransform(Affine2D.from_values(x, y, y, -x, 0, 0), ax.transData)
    rect = Rectangle((0, -100), 100, 200, transform=t, alpha=0.1, facecolor='grey')
    ax.add_patch(rect)

    ax.text(0.05, 0.95, str(figureid), transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top')

    matplotlib.pyplot.savefig("lr%04d.png" % figureid, dpi=50)