JAX (ライブラリ)
JAXは、高速な数値計算と大規模な機械学習のために設計されたPythonのオープンソースのライブラリ[6]。NumPy風の構文で書かれたPythonのソースコードをCPU・GPU・AIアクセラレータ[7]へコンパイルする実行時コンパイラや自動微分などを含む。 実行時コンパイラは、JAXからOpenXLAのXLAにコンパイルし、そこから先はハードウェア次第だが、多くのCPUとGPUはLLVMを経由してコンパイルされる[8]。 基本的な使用方法下記のソースコードのように、関数に @jit を付けることにより、その部分が実行時コンパイルされる。同一のソースコードで、CPUだけでなく、GPUやAIアクセラレータでも動作させることが可能である。詳細は後述するが、@jitの中に書けるのは普通のPythonのプログラムではなく、Pythonの構文を使用した純粋関数型言語である。 import jax.numpy as jnp
from jax import jit
@jit
def f(a, b):
return a + b
x = jnp.array([1, 2, 3], dtype=jnp.float32)
print(f(x, x))
map を自動ベクトル化した vmap があり、 from jax import jit, vmap
@jit
def f(a):
return vmap(lambda x: x * 2)(a)
Numbaとの違い似たようなライブラリとしてNumbaがあるが、以下の違いがある。純粋関数型にすることにより色々な最適化がかかっている。関数型言語としての分類は、純粋、正格評価、型を明示する必要が無い静的型付けである。
純粋関数型であるため、乱数を使用する際に、下記のように、乱数生成のキーを明示的に作り直さないといけない。[15] key, subkey = jax.random.split(key)
x = jax.random.normal(subkey)
配列を書き換える際は、手続き型では if文とmatch文JAXではPythonのif文とmatch文は基本的にはそのままでは使用できない。下記が用意されている。
while文とfor文JAXではPythonのwhile文とfor文は基本的にはそのままでは使用できず、ループ回数が定数の場合でPythonのfor文をそのまま使用した場合は、ループアンロールされる。[21] ループ構造を作るものとして下記が用意されている。
純粋関数型のため、scan, fori_loop, while_loop は全て前の計算結果を次に渡すという形となっている。 自動微分jax.grad にて自動微分できる。例えば、最急降下法は下記で実装できる。init_x から始めて、fori_loop にて iter_count 回、計算を反復している。 が最小となるx、つまり1を求めている。 from jax import jit, grad
from jax.lax import fori_loop
f = lambda x: (x - 1) ** 2
@jit
def gradient_descent(init_x, iter_count, learn_rate):
return fori_loop(0, iter_count, lambda i, x: x - learn_rate * grad(f)(x), init_x)
print(gradient_descent(0.0, 30, 0.3))
参照
関連項目外部リンク |
Portal di Ensiklopedia Dunia