外部回调#

本指南概述了各种回调函数的使用,这些回调函数允许JAX运行时在主机上执行Python代码,即使在jitvmapgrad或其他变换下运行时。

为什么使用回调?#

回调例程是一种在运行时执行主机端代码的方法。 作为一个简单的例子,假设您想在计算过程中打印某个变量的。 使用简单的Python print语句,代码如下:

import jax

@jax.jit
def f(x):
  y = x + 1
  print("intermediate value: {}".format(y))
  return y * 2

result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

打印的不是运行时值,而是跟踪时抽象值(如果您不熟悉 JAX 中的 跟踪,可以在 How To Think In JAX 找到一个很好的入门)。

要在运行时打印值,我们需要一个回调,例如 jax.debug.print

@jax.jit
def f(x):
  y = x + 1
  jax.debug.print("intermediate value: {}", y)
  return y * 2

result = f(2)
intermediate value: 3

这通过将由 y 表示的运行时值传回主机进程来实现,主机可以打印该值。

回调的类型#

在早期版本的 JAX 中,只有一种回调可用,来自 jax.experimental.host_callbackhost_callback 的例程存在一些不足之处,现在已经被多个针对不同场景设计的回调所取代:

(我们上面使用的 jax.debug.print() 函数是 jax.debug.callback() 的一个包装)。

从用户的角度来看,这三种回调主要通过它们允许的变换和编译器优化进行区分。

回调函数

支持返回值

jit

vmap

grad

scan/while_loop

保证执行

jax.pure_callback

❌¹

jax.experimental.io_callback

✅/❌²

✅³

jax.debug.callback

¹ jax.pure_callback 可以与 custom_jvp 一起使用,以使其与自动微分兼容。

² jax.experimental.io_callback 仅当 ordered=False 时与 vmap 兼容。

³ 请注意,io_callbackscan/while_loopvmap 具有复杂的语义,其行为可能在未来的版本中发生变化。

探索 jax.pure_callback#

jax.pure_callback 通常是你在想要纯函数的主机端执行时应该使用的回调函数:即没有副作用的函数(例如打印值、从磁盘读取数据、更新全局状态等)。

你传递给 jax.pure_callback 的函数不需要实际上是纯的,但它在 JAX 的变换和高阶函数中会被假设为纯的,这意味着它可能会被默默省略或多次调用。

import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  # 调用一个 numpy(非 jax.numpy)操作:
  return np.sin(x).astype(x.dtype)

def f(x):
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(f_host, result_shape, x)

x = jnp.arange(5.0)
f(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

因为 pure_callback 可以被省略或重复,它可以开箱即用地与 jitvmap 等变换,以及 scanwhile_loop 等高阶原语兼容:

jax.jit(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
jax.vmap(f)(x)
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)
def body_fun(_, x):
  return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0.       ,  0.841471 ,  0.9092974,  0.14112  , -0.7568025],      dtype=float32)

然而,由于JAX无法检测回调的内容,因此pure_callback具有未定义的自动微分语义:

%xmode minimal
Exception reporting mode: Minimal
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

有关使用 pure_callbackjax.custom_jvp 的示例,请参见下面的 示例:pure_callbackcustom_jvp

根据设计,传递给 pure_callback 的函数被视为没有副作用:这导致的一个后果是,如果函数的输出未被使用,编译器可能会完全消除该回调。

def print_something():
  print('printing something')
  return np.int32(0)

@jax.jit
def f1():
  return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
  jax.pure_callback(print_something, np.int32(0))
  return 1.0
f2();

f1中,回调的输出被用作函数的返回值,因此回调被执行,我们可以看到打印的输出。另一方面,在f2中,回调的输出未被使用,因此编译器注意到了这一点并消除了函数调用。这对于没有副作用的函数的回调来说是正确的语义。

探索 jax.experimental.io_callback#

jax.pure_callback() 相比,jax.experimental.io_callback() 明确用于不纯函数,即具有副作用的函数。

作为示例,以下是一个对全局主机端 numpy 随机生成器的回调。这是一个不纯操作,因为在 numpy 中生成随机数的副作用是随机状态会被更新(请注意,这只是 io_callback 的一个玩具示例,并不一定是 JAX 中生成随机数的推荐方法!)。

from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """使用global_rng状态生成一个随机数组,类似于x"""
  # 我们这里有两个副作用:
  # - 打印形状和数据类型
  # - 调用 global_rng,从而更新其状态
  print(f'generating {x.dtype}{list(x.shape)}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ],      dtype=float32)

io_callback 默认与 vmap 兼容:

jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625  , 0.9350724 ],      dtype=float32)

请注意,这可能会以任何顺序执行映射的回调。因此,例如,如果您在 GPU 上运行此操作,映射输出的顺序可能会因运行而异。

如果保持回调顺序很重要,您可以设置 ordered=True,在这种情况下,尝试 vmap 将会引发错误:

@jax.jit
def numpy_random_like_ordered(x):
  return io_callback(host_side_random_like, x, x, ordered=True)

jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.

另一方面,scanwhile_loop 在使用 io_callback 时,不论是否强制顺序都是有效的:

def body_fun(_, x):
  return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544],      dtype=float32)

pure_callback 一样,如果将一个被微分的变量传递给 io_callback,它在自动微分中会失败:

jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.

然而,如果回调函数不依赖于一个分化变量,它将会执行:

@jax.jit
def f(x):
  io_callback(lambda: print('hello'), None)
  return x

jax.grad(f)(1.0);
hello

pure_callback不同,编译器在这种情况下不会移除回调执行,即使回调的输出在后续计算中未被使用。

探索 debug.callback#

pure_callbackio_callback 都对它们调用的函数的纯度施加了一些假设,并在各种方面限制了 JAX 转换和编译机制的行为。debug.callback 本质上对回调函数 没有 假设,因此回调的操作准确反映了 JAX 在程序执行过程中的行为。此外,debug.callback 不能 向程序返回任何值。

from jax import debug

def log_value(x):
  # 这可能是一个实际的日志调用;我们将使用
  # 用于演示的print()
  print("log:", x)

@jax.jit
def f(x):
  debug.callback(log_value, x)
  return x

f(1.0);
log: 1.0

调试回调与vmap兼容:

x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0

并且还与 grad 和其他自动微分变换兼容

jax.grad(f)(1.0)
log: 1.0
Array(1., dtype=float32, weak_type=True)

这可以使 debug.callback 对于通用调试比 pure_callbackio_callback 更加实用。

示例:结合 custom_jvppure_callback#

利用 jax.pure_callback() 的一种强大方法是将其与 jax.custom_jvp 结合起来(有关 custom_jvp 的更多详细信息,请参见 自定义导数规则)。 假设我们想为尚未在 jax.scipyjax.numpy 包装器中提供的 scipy 或 numpy 函数创建一个兼容 JAX 的包装器。

在这里,我们将考虑为第一类贝塞尔函数创建一个包装器,该函数在 scipy.special.jv 中实现。 我们可以通过定义一个简单的 pure_callback 开始:

import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # 要求变量 v 为整数类型:这简化了
  # 请参阅下面的JVP规则。
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # 将输入提升为不精确(浮点数/复数)。
  # 请注意,jnp.result_type() 会考虑 enable_x64 标志的影响。
  z = z.astype(jnp.result_type(float, z.dtype))

  # 将scipy函数封装以返回预期的数据类型。
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # 定义输出的预期形状和数据类型。
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # 我们使用vectorize=True,因为scipy.special.jv能够处理广播输入。
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

这使我们可以从经过变换的 JAX 代码中调用 scipy.special.jv,包括经过 jitvmap 变换时:

j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

这是使用 jit 得到的相同结果:

print(jax.jit(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

这里是使用 vmap 得到的相同结果:

print(jax.vmap(j1)(z))
[ 0.          0.44005057  0.5767248   0.33905897 -0.06604332]

然而,如果我们调用 jax.grad,我们会看到一个错误,因为这个函数没有定义自动微分规则:

jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

让我们为此定义一个自定义的梯度规则。查看Bessel第一类函数的定义,我们发现关于自变量z的导数有一个相对简单的递归关系:

根据给定的递归关系,我们可以修正关于 Bessel 第一类函数 \(J_\nu(z)\) 的导数公式。这个公式给出了 \(J_\nu(z)\) 对于自变量 \(z\) 的导数。实际的导数公式应该如下:

\[\begin{split} \frac{d}{dz} J_\nu(z) = \left\{ \begin{array}{ll} -J_1(z), & \nu=0 \\[1em] \frac{1}{2} [J_{\nu - 1}(z) - J_{\nu + 1}(z)], & \nu \neq 0 \end{array} \right. \end{split}\]

关于\(\nu\)的梯度更为复杂,但由于我们将v参数限制为整数类型,因此在这个例子中我们不需要担心它的梯度。

我们可以使用jax.custom_jvp来为我们的回调函数定义这个自动微分规则:

jv = jax.custom_jvp(jv)

@jv.defjvp
def _jv_jvp(primals, tangents):
  v, z = primals
  _, z_dot = tangents  # 注意:v_dot始终为0,因为v是整数。
  jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
  djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
  return jv(v, z), z_dot * djv_dz

现在计算我们函数的梯度将正常工作:

j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162

进一步来说,由于我们已经根据 jv 自身定义了我们的梯度,JAX 的架构意味着我们可以免费获得二阶及更高阶的导数:

jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)

请注意,虽然这一切在JAX中工作正常,但每次调用我们的基于回调的jv函数时,都将导致输入数据从设备传递到主机,并将scipy.special.jv的输出从主机传回设备。 在像GPU或TPU这样的加速器上运行时,这种数据移动和主机同步可能导致每次调用jv时产生显著的开销。 然而,如果您在单个CPU上运行JAX(“主机”和“设备”位于同一硬件上),JAX通常会以快速的零复制方式进行数据传输,这使得这种模式相对简单,能够扩展JAX的功能。