-Python- Matplotlib 等高線の描画

Matplotlibを用いて等高線を描画してみます.

 x^2 + \cfrac{y^2}{4} = k

のkをいくつか変えて曲線を描くことを考えます.上記の方程式はkを固定すると楕円になります.ここでは k = 1, 2, 3, 4, 5 のときの曲線を描いてみます.
スクリプト例を示します.

 
# Import Module
import matplotlib.pyplot as plt
import numpy as np
 
def f(x, y): 
    return x ** 2 + y ** 2 / 4
 
x = np.linspace(-5, 5, 300)
y = np.linspace(-5, 5, 300)
xmesh, ymesh = np.meshgrid(x, y)
z = f(xmesh.ravel(), ymesh.ravel()).reshape(xmesh.shape)
 
plt.contour(x, y, z, colors = 'k', levels = [1, 2, 3, 4, 5])
 
plt.show()
 

実行結果は以下のようになります.

f:id:HidehikoMURAO:20181101155323p:plain

等高線を描くには平面を細かいグリッド(格子)に区切って,それぞれの点で関数の値を評価する必要があるのですが,meshgrid にはそれを補助するための関数です.
話を簡単にするために x座標の値として1,2,3を,y座標の値として4, 5, 6を考え,その値に対するグリッド上の9つの点の座標を考えます.REPL上で例示すると以下のようになります.
>>> import numpy as np
>>> x = np.array([1, 2, 3])
>>> y = np.array([4, 5, 6])
>>> xmesh, ymesh = np.meshgrid(x, y)
>>> xmesh
array([[1, 2, 3],
       [1, 2, 3],
       [1, 2, 3]])
>>> ymesh
array([[4, 4, 4],
       [5, 5, 5],
       [6, 6, 6]])
>>> 
 
これらの点の座標を計算するのがmeshgridです.変数として1, 2, 3からなる配列を,変数yとして4, 5, 6からなる配列を用意し,そのxとyを引数としてmeshgridを呼び出すと,2つの2次元配列が返ってきます.それらにはそれぞれグリッド点のx座標とy座標が並んでいます.
 
上記のプログラムでは,8-10行目ではxとyに[-5, 5]区間を細かく刻んだものを入れ,それをもとにmeshgridを呼んでいます.
xとyは2次元配列ですが,11行目ではそれについて関数を評価しています.関数fの中身は四則演算とべき乗だけなので,全てブロードキャストが効いて,2つの引数が2次元配列でも計算でき,その結果も2次元配列になります.一般にはここで評価する関数は2次元配列に対応しているとは限りません(機械学習関連の関数は1次元配列にしか対応していないことも多々あります).そのような場合はravelメソッドなどを使って一度2次元配列に直してから評価し,その後reshapeで2次元配列に変換します.11行目の実行が終わった時点では,xmeshとymeshにはそれぞれグリッド点のx座標とy座標がそれぞれ入っており,zにはグリッド点上の関数の評価値が入ります.
xmesh,ymesh,zはすべて2次元配列で同じ形状になっています.
13(下から2)行目ではこれらを引数にして等高線を描いています.引数levelsでは等高線を描く関数値を指定しています.つまり,この場合は関数fの評価値が1, 2, 3, 4, 5のところで等高線を描きます.
関数contourは等高線の線だけを描画しますが,関数の評価値によって領域を塗り分けるにはcountoutを用います.
 

import matplotlib.pyplot as plt

import numpy as np

 

 

def f(x, y):

    return x**2 + y**2 / 4

 

 

x = np.linspace(-5, 5, 300)

y = np.linspace(-5, 5, 300)

xmesh, ymesh = np.meshgrid(x, y)

z = f(xmesh, ymesh)

 

colors = ["0.1", "0.3", "0.5", "0.7"]

levels = [1, 2, 3, 4, 5]

plt.contourf(x, y, z, colors=colors, levels=levels)

plt.show()

f:id:HidehikoMURAO:20181101224757p:plain