在JAX中编写自定义Jaxpr解释器#

在Colab中打开 在Kaggle中打开

JAX 提供了多个可组合的函数转换(jit, grad, vmap 等),使得编写简洁、高效的代码成为可能。

在这里,我们展示如何通过编写自定义的 Jaxpr 解释器将自己的函数转换添加到系统中。这样,我们就能免费获得与其他转换的可组合性。

此示例使用了内部 JAX API,可能随时会中断。未在 API 文档 中的内容应假定为内部使用。

import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random

JAX在做什么?#

JAX提供了类似NumPy的API用于数值计算,可以直接使用,但JAX的真正力量来自于可组合的函数转换。以jit函数转换为例,它接收一个函数并返回一个语义上相同的函数,但该函数是由XLA为加速器延迟编译的。

x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)

当我们调用 fast_f 时,会发生什么?JAX 追踪该函数并构建一个 XLA 计算图。然后,图会被 JIT 编译并执行。其他变换的工作方式相似,它们首先追踪该函数并以某种方式处理输出跟踪。要了解更多关于 JAX 的追踪机制的信息,可以查看 README 中的 “它是如何工作的” 部分。

Jaxpr 跟踪器#

在 Jax 中,一个特别重要的跟踪器是 Jaxpr 跟踪器,它将操作记录到 Jaxpr(Jax 表达式)中。 Jaxpr 是一种数据结构,可以像迷你函数式编程语言一样进行求值,因此 Jaxpr 是函数转换的一个有用的中间表示。

要初步了解Jaxprs,请考虑make_jaxpr转换。 make_jaxpr本质上是一种“美化打印”转换:它将一个函数转换为一个函数,该函数在给定示例参数时产生其计算的Jaxpr表示。 make_jaxpr对于调试和反思非常有用。 让我们用它来看看一些示例Jaxprs是如何构造的。

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=4464342144):int32[]]
outvars: [Var(id=4464337536):int32[]]
constvars: []
equation: [Var(id=4464342144):int32[], 1] add [Var(id=4464337536):int32[]] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }

bar
=====
invars: [Var(id=4464599616):float32[5,10], Var(id=4464599936):float32[5], Var(id=4464600000):float32[10]]
outvars: [Var(id=4464600256):float32[5], Var(id=4464600000):float32[10]]
constvars: []
equation: [Var(id=4464599616):float32[5,10], Var(id=4464600000):float32[10]] dot_general [Var(id=4464600064):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=4464600064):float32[5], Var(id=4464599936):float32[5]] add [Var(id=4464600128):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=4464600192):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=4464600128):float32[5], Var(id=4464600192):float32[5]] add [Var(id=4464600256):float32[5]] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    g:f32[5] = add e f
  in (g, c) }
  • jaxpr.invars - Jaxpr 的 invars 是输入变量的列表,类似于 Python 函数中的参数。

  • jaxpr.outvars - Jaxpr 的 outvars 是 Jaxpr 返回的变量。每个 Jaxpr 有多个输出。

  • jaxpr.constvars - constvars 是一组变量,它们也是 Jaxpr 的输入,但对应于跟踪中的常量(我们稍后会详细讲解这些)。

  • jaxpr.eqns - 方程列表,本质上是让绑定。每个方程都是输入变量列表、输出变量列表和一个用于评估输入以产生输出的 原语。每个方程还有一个 params,它是参数的字典。

总之,Jaxpr 封装了一个简单的程序,可以用输入进行评估以生成输出。我们稍后会详细介绍如何做到这一点。现在需要注意的重要事项是,Jaxpr 是一个可以以我们想要的方式进行操作和评估的数据结构。

Jaxprs 为什么有用?#

Jaxprs 是简单的程序表示,易于转换。由于 Jax 允许我们将 Jaxprs 从 Python 函数中分离出来,这为我们提供了一种转换用 Python 编写的数值程序的方法。

你的第一个解释器:invert#

让我们尝试实现一个简单的函数“inverter”,该函数接受原始函数的输出并返回产生这些输出的输入。现在,让我们专注于由其他可逆的单一函数组成的简单单一函数。

目标:

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

我们将通过以下方式实现这一点:(1)将 f 跟踪到 Jaxpr 中,然后(2)反向解释 Jaxpr。在反向解释 Jaxpr 时,对于每个方程,我们将查找原语的逆在一个表中并应用它。

1. 跟踪一个函数#

让我们使用 make_jaxpr 将一个函数跟踪到 Jaxpr 中。

# 导入用于追踪/解释的Jax函数。
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map

jax.make_jaxpr 返回一个闭合 Jaxpr,即已经与跟踪中的常量(literals)捆绑在一起的 Jaxpr。

def f(x):
  return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]

2. 评估 Jaxpr#

在我们编写自定义 Jaxpr 解释器之前,首先实现 “默认” 解释器 eval_jaxpr,该解释器将按原样评估 Jaxpr,计算与原始未转换的 Python 函数相同的值。

为此,我们首先创建一个环境来存储每个变量的值,并在评估 Jaxpr 中的每个方程时更新该环境。

def eval_jaxpr(jaxpr, consts, *args):
  # 从变量到值的映射
  env = {}

  def read(var):
    # 字面量是嵌入在 Jaxpr 中的值
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # 将参数和常量绑定到环境中
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # 遍历方程并使用 `bind` 评估原语
  for eqn in jaxpr.eqns:
    # 从环境中读取方程的输入
    invals = safe_map(read, eqn.invars)
    # `bind` 是原始调用的方式
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # 原语可能会返回多个输出,也可能不会。
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # 将原始结果写入环境
    safe_map(write, eqn.outvars, outvals)
  # 从环境中读取Jaxpr的最终结果
  return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]

请注意,eval_jaxpr 将始终返回一个扁平列表,即使原始函数并非如此。

此外,这个解释器不处理高阶原语(例如 jitpmap),这些我们在本指南中将不予以讨论。您可以参考 core.eval_jaxpr链接)以查看该解释器未涵盖的边缘案例。

自定义 inverse Jaxpr 解释器#

一个 inverse 解释器与 eval_jaxpr 看起来没有太大区别。我们首先设置一个注册表,将原语映射到它们的逆操作。然后我们将编写一个自定义解释器,该解释器在注册表中查找原语。

事实证明,这个解释器也将类似于在反向模式自动微分中使用的“转置”解释器 链接在这里

inverse_registry = {}

我们现在将为一些原语注册逆运算。根据惯例,Jax 中的原语以 _p 结尾,并且许多流行的原语位于 lax 中。

inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh

inverse将首先追踪函数,然后自定义解释Jaxpr。让我们设置一个简单的框架。

def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # 由于我们假设是一元函数,因此不必担心扁平化问题。
    # 解平化论证。
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped

现在我们只需要定义 inverse_jaxpr,它将向后遍历 Jaxpr,并在可能的情况下反转原语。

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val
  # 现在,参数对应于Jaxpr的输出变量。
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # 回溯循环
  for eqn in jaxpr.eqns[::-1]:
    # 输出变量现在成为了输入变量
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # 假设一个一元函数
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)

好的!请提供您希望翻译的ipynb文件内容。

def f(x):
  return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)

重要的是,您可以通过Jaxpr解释器进行跟踪。

jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }

添加新转换到系统中只需要这些,你可以免费获得与所有其他转换的组合!例如,我们可以将 jitvmapgradinverse 一起使用!

jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1.       ],      dtype=float32, weak_type=True)

读者练习#

  • 处理具有多个参数的原始函数,当输入部分已知时,例如 lax.add_plax.mul_p

  • 处理 xla_callxla_pmap 原始函数,这些函数在当前形式下无法与 eval_jaxprinverse_jaxpr 一起使用。