NumPy+Matplotlibで散布図の上に平均と分散を表示する

先日「NumPyで散布図を書く」ではscatterを使って散布図を描くところまでをやった。今日はこの上に分散と平均を示す楕円を描き足したい。

PatchCollectionのサンプルを見てみるとpylab.figureでmatplotlib.figure.Figureを作って、fig.add_subplotでmatplotlib.axes.AxesSubplotを作って、そこにPatchCollectionをadd_collectionするという流れになっている。ref. api example code: patch_collection.py — Matplotlib v1.1.0 documentation

が、しかし、今したいのはscatterで作られた散布図の上に描き足すことであって、pylab.figure()すると新しい図を作ってしまうからダメ。最初からpylab.figure()しておいてそこにscatterを描くのなら行けそうだけど、人間は得てして事前の準備を忘れるもんなので、後から何とかしたい。そこでgcaを使う。

ax = gca()

「Return the current axis instance.」だそうな。というわけで散布図を描いたあとでaxisインスタンスを取得して、描きたい楕円のインスタンスを作って適当な位置にtransformし、add_collectionしてからsavefigすれば出来上がり。ref. Working with transformations — Matplotlib v1.1.0 documentation

def draw():
    clf()
    scatter(data[:, 0], data[:, 1], alpha=0.5, marker="+")
    ax = gca()
    circles = [
        Circle(
            (0, 0), radius=3, 
            transform=to_transform(mu, sigma)) for mu, sigma in zip(mus, sigmas)]
    ax.add_collection(
        PatchCollection(circles, alpha=0.2))
    matplotlib.pyplot.savefig('test.png')

このtransformを作る所に関しては自作のto_transformって関数にまとめておいた。共分散行列をそのまま使ったのでは円の半径が分散に比例してしまうので、そうではなく標準偏差に比例させるために一度固有値分解してから固有値平方根をとってある。これと、上のCircleの引数radius=3とで「各軸で3標準偏差の点を通る等高線を引く」を実現しているつもり。

def to_transform(mu, sigma):
    val, vec = eigh(sigma)
    trans = diag(sqrt(val)).dot(vec)
    return Affine2D.from_values(*trans.flatten(), e=mu[0], f=mu[1])

ところでCircleの引数でalpha=0.5とかfill=Falseとか指定できるとドキュメントには書いてあるのに、指定しても反映されないのはなぜなんだろう。本当はPRMLのfig 9.8みたいなのを作りたいんだが。

追記: PatchCollectionのmatch_originalをTrueにしないと色が上書きされてしまうことがわかった http://d.hatena.ne.jp/nishiohirokazu/20111111/1320994184