在JAX中编写自定义Jaxpr解释器#
JAX 提供了多个可组合的函数转换(jit
, grad
, vmap
等),使得编写简洁、高效的代码成为可能。
在这里,我们展示如何通过编写自定义的 Jaxpr 解释器将自己的函数转换添加到系统中。这样,我们就能免费获得与其他转换的可组合性。
此示例使用了内部 JAX API,可能随时会中断。未在 API 文档 中的内容应假定为内部使用。
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
JAX在做什么?#
JAX提供了类似NumPy的API用于数值计算,可以直接使用,但JAX的真正力量来自于可组合的函数转换。以jit
函数转换为例,它接收一个函数并返回一个语义上相同的函数,但该函数是由XLA为加速器延迟编译的。
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
当我们调用 fast_f
时,会发生什么?JAX 追踪该函数并构建一个 XLA 计算图。然后,图会被 JIT 编译并执行。其他变换的工作方式相似,它们首先追踪该函数并以某种方式处理输出跟踪。要了解更多关于 JAX 的追踪机制的信息,可以查看 README 中的 “它是如何工作的” 部分。
Jaxpr 跟踪器#
在 Jax 中,一个特别重要的跟踪器是 Jaxpr 跟踪器,它将操作记录到 Jaxpr(Jax 表达式)中。 Jaxpr 是一种数据结构,可以像迷你函数式编程语言一样进行求值,因此 Jaxpr 是函数转换的一个有用的中间表示。
要初步了解Jaxprs,请考虑make_jaxpr
转换。 make_jaxpr
本质上是一种“美化打印”转换:它将一个函数转换为一个函数,该函数在给定示例参数时产生其计算的Jaxpr表示。 make_jaxpr
对于调试和反思非常有用。 让我们用它来看看一些示例Jaxprs是如何构造的。
def examine_jaxpr(closed_jaxpr):
jaxpr = closed_jaxpr.jaxpr
print("invars:", jaxpr.invars)
print("outvars:", jaxpr.outvars)
print("constvars:", jaxpr.constvars)
for eqn in jaxpr.eqns:
print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
print()
print("jaxpr:", jaxpr)
def foo(x):
return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
print()
def bar(w, b, x):
return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=4464342144):int32[]]
outvars: [Var(id=4464337536):int32[]]
constvars: []
equation: [Var(id=4464342144):int32[], 1] add [Var(id=4464337536):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
bar
=====
invars: [Var(id=4464599616):float32[5,10], Var(id=4464599936):float32[5], Var(id=4464600000):float32[10]]
outvars: [Var(id=4464600256):float32[5], Var(id=4464600000):float32[10]]
constvars: []
equation: [Var(id=4464599616):float32[5,10], Var(id=4464600000):float32[10]] dot_general [Var(id=4464600064):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=4464600064):float32[5], Var(id=4464599936):float32[5]] add [Var(id=4464600128):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=4464600192):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=4464600128):float32[5], Var(id=4464600192):float32[5]] add [Var(id=4464600256):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
d:f32[5] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] a c
e:f32[5] = add d b
f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
g:f32[5] = add e f
in (g, c) }
jaxpr.invars
- Jaxpr 的invars
是输入变量的列表,类似于 Python 函数中的参数。jaxpr.outvars
- Jaxpr 的outvars
是 Jaxpr 返回的变量。每个 Jaxpr 有多个输出。jaxpr.constvars
-constvars
是一组变量,它们也是 Jaxpr 的输入,但对应于跟踪中的常量(我们稍后会详细讲解这些)。jaxpr.eqns
- 方程列表,本质上是让绑定。每个方程都是输入变量列表、输出变量列表和一个用于评估输入以产生输出的 原语。每个方程还有一个params
,它是参数的字典。
总之,Jaxpr 封装了一个简单的程序,可以用输入进行评估以生成输出。我们稍后会详细介绍如何做到这一点。现在需要注意的重要事项是,Jaxpr 是一个可以以我们想要的方式进行操作和评估的数据结构。
Jaxprs 为什么有用?#
Jaxprs 是简单的程序表示,易于转换。由于 Jax 允许我们将 Jaxprs 从 Python 函数中分离出来,这为我们提供了一种转换用 Python 编写的数值程序的方法。
你的第一个解释器:invert
#
让我们尝试实现一个简单的函数“inverter”,该函数接受原始函数的输出并返回产生这些输出的输入。现在,让我们专注于由其他可逆的单一函数组成的简单单一函数。
目标:
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
我们将通过以下方式实现这一点:(1)将 f
跟踪到 Jaxpr 中,然后(2)反向解释 Jaxpr。在反向解释 Jaxpr 时,对于每个方程,我们将查找原语的逆在一个表中并应用它。
1. 跟踪一个函数#
让我们使用 make_jaxpr
将一个函数跟踪到 Jaxpr 中。
# 导入用于追踪/解释的Jax函数。
from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map
jax.make_jaxpr
返回一个闭合 Jaxpr,即已经与跟踪中的常量(literals
)捆绑在一起的 Jaxpr。
def f(x):
return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
2. 评估 Jaxpr#
在我们编写自定义 Jaxpr 解释器之前,首先实现 “默认” 解释器 eval_jaxpr
,该解释器将按原样评估 Jaxpr,计算与原始未转换的 Python 函数相同的值。
为此,我们首先创建一个环境来存储每个变量的值,并在评估 Jaxpr 中的每个方程时更新该环境。
def eval_jaxpr(jaxpr, consts, *args):
# 从变量到值的映射
env = {}
def read(var):
# 字面量是嵌入在 Jaxpr 中的值
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# 将参数和常量绑定到环境中
safe_map(write, jaxpr.invars, args)
safe_map(write, jaxpr.constvars, consts)
# 遍历方程并使用 `bind` 评估原语
for eqn in jaxpr.eqns:
# 从环境中读取方程的输入
invals = safe_map(read, eqn.invars)
# `bind` 是原始调用的方式
outvals = eqn.primitive.bind(*invals, **eqn.params)
# 原语可能会返回多个输出,也可能不会。
if not eqn.primitive.multiple_results:
outvals = [outvals]
# 将原始结果写入环境
safe_map(write, eqn.outvars, outvals)
# 从环境中读取Jaxpr的最终结果
return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
请注意,eval_jaxpr
将始终返回一个扁平列表,即使原始函数并非如此。
此外,这个解释器不处理高阶原语(例如 jit
和 pmap
),这些我们在本指南中将不予以讨论。您可以参考 core.eval_jaxpr
(链接)以查看该解释器未涵盖的边缘案例。
自定义 inverse
Jaxpr 解释器#
一个 inverse
解释器与 eval_jaxpr
看起来没有太大区别。我们首先设置一个注册表,将原语映射到它们的逆操作。然后我们将编写一个自定义解释器,该解释器在注册表中查找原语。
事实证明,这个解释器也将类似于在反向模式自动微分中使用的“转置”解释器 链接在这里。
inverse_registry = {}
我们现在将为一些原语注册逆运算。根据惯例,Jax 中的原语以 _p
结尾,并且许多流行的原语位于 lax
中。
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
inverse
将首先追踪函数,然后自定义解释Jaxpr。让我们设置一个简单的框架。
def inverse(fun):
@wraps(fun)
def wrapped(*args, **kwargs):
# 由于我们假设是一元函数,因此不必担心扁平化问题。
# 解平化论证。
closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
return out[0]
return wrapped
现在我们只需要定义 inverse_jaxpr
,它将向后遍历 Jaxpr,并在可能的情况下反转原语。
def inverse_jaxpr(jaxpr, consts, *args):
env = {}
def read(var):
if type(var) is core.Literal:
return var.val
return env[var]
def write(var, val):
env[var] = val
# 现在,参数对应于Jaxpr的输出变量。
safe_map(write, jaxpr.outvars, args)
safe_map(write, jaxpr.constvars, consts)
# 回溯循环
for eqn in jaxpr.eqns[::-1]:
# 输出变量现在成为了输入变量
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
# 假设一个一元函数
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
return safe_map(read, jaxpr.invars)
好的!请提供您希望翻译的ipynb文件内容。
def f(x):
return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
重要的是,您可以通过Jaxpr解释器进行跟踪。
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
添加新转换到系统中只需要这些,你可以免费获得与所有其他转换的组合!例如,我们可以将 jit
、vmap
和 grad
与 inverse
一起使用!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1. ], dtype=float32, weak_type=True)
读者练习#
处理具有多个参数的原始函数,当输入部分已知时,例如
lax.add_p
、lax.mul_p
。处理
xla_call
和xla_pmap
原始函数,这些函数在当前形式下无法与eval_jaxpr
和inverse_jaxpr
一起使用。