外部回调#
本指南概述了各种回调函数的使用,这些回调函数允许JAX运行时在主机上执行Python代码,即使在jit
、vmap
、grad
或其他变换下运行时。
为什么使用回调?#
回调例程是一种在运行时执行主机端代码的方法。
作为一个简单的例子,假设您想在计算过程中打印某个变量的值。
使用简单的Python print
语句,代码如下:
import jax
@jax.jit
def f(x):
y = x + 1
print("intermediate value: {}".format(y))
return y * 2
result = f(2)
intermediate value: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
打印的不是运行时值,而是跟踪时抽象值(如果您不熟悉 JAX 中的 跟踪,可以在 How To Think In JAX 找到一个很好的入门)。
要在运行时打印值,我们需要一个回调,例如 jax.debug.print
:
@jax.jit
def f(x):
y = x + 1
jax.debug.print("intermediate value: {}", y)
return y * 2
result = f(2)
intermediate value: 3
这通过将由 y
表示的运行时值传回主机进程来实现,主机可以打印该值。
回调的类型#
在早期版本的 JAX 中,只有一种回调可用,来自 jax.experimental.host_callback
。host_callback
的例程存在一些不足之处,现在已经被多个针对不同场景设计的回调所取代:
jax.pure_callback()
:适合纯函数;即没有副作用的函数。jax.experimental.io_callback()
:适合不纯函数;例如,读取或写入磁盘数据的函数。jax.debug.callback()
:适合需要反映编译器执行行为的函数。
(我们上面使用的 jax.debug.print()
函数是 jax.debug.callback()
的一个包装)。
从用户的角度来看,这三种回调主要通过它们允许的变换和编译器优化进行区分。
回调函数 |
支持返回值 |
|
|
|
|
保证执行 |
---|---|---|---|---|---|---|
|
✅ |
✅ |
✅ |
❌¹ |
✅ |
❌ |
|
✅ |
✅ |
✅/❌² |
❌ |
✅³ |
✅ |
|
❌ |
✅ |
✅ |
✅ |
✅ |
❌ |
¹ jax.pure_callback
可以与 custom_jvp
一起使用,以使其与自动微分兼容。
² jax.experimental.io_callback
仅当 ordered=False
时与 vmap
兼容。
³ 请注意,io_callback
的 scan
/while_loop
的 vmap
具有复杂的语义,其行为可能在未来的版本中发生变化。
探索 jax.pure_callback
#
jax.pure_callback
通常是你在想要纯函数的主机端执行时应该使用的回调函数:即没有副作用的函数(例如打印值、从磁盘读取数据、更新全局状态等)。
你传递给 jax.pure_callback
的函数不需要实际上是纯的,但它在 JAX 的变换和高阶函数中会被假设为纯的,这意味着它可能会被默默省略或多次调用。
import jax
import jax.numpy as jnp
import numpy as np
def f_host(x):
# 调用一个 numpy(非 jax.numpy)操作:
return np.sin(x).astype(x.dtype)
def f(x):
result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.pure_callback(f_host, result_shape, x)
x = jnp.arange(5.0)
f(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
因为 pure_callback
可以被省略或重复,它可以开箱即用地与 jit
和 vmap
等变换,以及 scan
和 while_loop
等高阶原语兼容:
jax.jit(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
jax.vmap(f)(x)
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
def body_fun(_, x):
return _, f(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)
然而,由于JAX无法检测回调的内容,因此pure_callback
具有未定义的自动微分语义:
%xmode minimal
Exception reporting mode: Minimal
jax.grad(f)(x)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
有关使用 pure_callback
和 jax.custom_jvp
的示例,请参见下面的 示例:pure_callback
与 custom_jvp
。
根据设计,传递给 pure_callback
的函数被视为没有副作用:这导致的一个后果是,如果函数的输出未被使用,编译器可能会完全消除该回调。
def print_something():
print('printing something')
return np.int32(0)
@jax.jit
def f1():
return jax.pure_callback(print_something, np.int32(0))
f1();
printing something
@jax.jit
def f2():
jax.pure_callback(print_something, np.int32(0))
return 1.0
f2();
在f1
中,回调的输出被用作函数的返回值,因此回调被执行,我们可以看到打印的输出。另一方面,在f2
中,回调的输出未被使用,因此编译器注意到了这一点并消除了函数调用。这对于没有副作用的函数的回调来说是正确的语义。
探索 jax.experimental.io_callback
#
与 jax.pure_callback()
相比,jax.experimental.io_callback()
明确用于不纯函数,即具有副作用的函数。
作为示例,以下是一个对全局主机端 numpy 随机生成器的回调。这是一个不纯操作,因为在 numpy 中生成随机数的副作用是随机状态会被更新(请注意,这只是 io_callback
的一个玩具示例,并不一定是 JAX 中生成随机数的推荐方法!)。
from jax.experimental import io_callback
from functools import partial
global_rng = np.random.default_rng(0)
def host_side_random_like(x):
"""使用global_rng状态生成一个随机数组,类似于x"""
# 我们这里有两个副作用:
# - 打印形状和数据类型
# - 调用 global_rng,从而更新其状态
print(f'generating {x.dtype}{list(x.shape)}')
return global_rng.uniform(size=x.shape).astype(x.dtype)
@jax.jit
def numpy_random_like(x):
return io_callback(host_side_random_like, x, x)
x = jnp.zeros(5)
numpy_random_like(x)
generating float32[5]
Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)
io_callback
默认与 vmap
兼容:
jax.vmap(numpy_random_like)(x)
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)
请注意,这可能会以任何顺序执行映射的回调。因此,例如,如果您在 GPU 上运行此操作,映射输出的顺序可能会因运行而异。
如果保持回调顺序很重要,您可以设置 ordered=True
,在这种情况下,尝试 vmap
将会引发错误:
@jax.jit
def numpy_random_like_ordered(x):
return io_callback(host_side_random_like, x, x, ordered=True)
jax.vmap(numpy_random_like_ordered)(x)
ValueError: Cannot `vmap` ordered IO callback.
另一方面,scan
和 while_loop
在使用 io_callback
时,不论是否强制顺序都是有效的:
def body_fun(_, x):
return _, numpy_random_like_ordered(x)
jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
generating float32[]
Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)
像 pure_callback
一样,如果将一个被微分的变量传递给 io_callback
,它在自动微分中会失败:
jax.grad(numpy_random_like)(x)
ValueError: IO callbacks do not support JVP.
然而,如果回调函数不依赖于一个分化变量,它将会执行:
@jax.jit
def f(x):
io_callback(lambda: print('hello'), None)
return x
jax.grad(f)(1.0);
hello
与pure_callback
不同,编译器在这种情况下不会移除回调执行,即使回调的输出在后续计算中未被使用。
探索 debug.callback
#
pure_callback
和 io_callback
都对它们调用的函数的纯度施加了一些假设,并在各种方面限制了 JAX 转换和编译机制的行为。debug.callback
本质上对回调函数 没有 假设,因此回调的操作准确反映了 JAX 在程序执行过程中的行为。此外,debug.callback
不能 向程序返回任何值。
from jax import debug
def log_value(x):
# 这可能是一个实际的日志调用;我们将使用
# 用于演示的print()
print("log:", x)
@jax.jit
def f(x):
debug.callback(log_value, x)
return x
f(1.0);
log: 1.0
调试回调与vmap
兼容:
x = jnp.arange(5.0)
jax.vmap(f)(x);
log: 0.0
log: 1.0
log: 2.0
log: 3.0
log: 4.0
并且还与 grad
和其他自动微分变换兼容
jax.grad(f)(1.0)
log: 1.0
Array(1., dtype=float32, weak_type=True)
这可以使 debug.callback
对于通用调试比 pure_callback
或 io_callback
更加实用。
示例:结合 custom_jvp
的 pure_callback
#
利用 jax.pure_callback()
的一种强大方法是将其与 jax.custom_jvp
结合起来(有关 custom_jvp
的更多详细信息,请参见 自定义导数规则)。
假设我们想为尚未在 jax.scipy
或 jax.numpy
包装器中提供的 scipy 或 numpy 函数创建一个兼容 JAX 的包装器。
在这里,我们将考虑为第一类贝塞尔函数创建一个包装器,该函数在 scipy.special.jv
中实现。
我们可以通过定义一个简单的 pure_callback
开始:
import jax
import jax.numpy as jnp
import scipy.special
def jv(v, z):
v, z = jnp.asarray(v), jnp.asarray(z)
# 要求变量 v 为整数类型:这简化了
# 请参阅下面的JVP规则。
assert jnp.issubdtype(v.dtype, jnp.integer)
# 将输入提升为不精确(浮点数/复数)。
# 请注意,jnp.result_type() 会考虑 enable_x64 标志的影响。
z = z.astype(jnp.result_type(float, z.dtype))
# 将scipy函数封装以返回预期的数据类型。
_scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
# 定义输出的预期形状和数据类型。
result_shape_dtype = jax.ShapeDtypeStruct(
shape=jnp.broadcast_shapes(v.shape, z.shape),
dtype=z.dtype)
# 我们使用vectorize=True,因为scipy.special.jv能够处理广播输入。
return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
这使我们可以从经过变换的 JAX 代码中调用 scipy.special.jv
,包括经过 jit
和 vmap
变换时:
j1 = partial(jv, 1)
z = jnp.arange(5.0)
print(j1(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
这是使用 jit
得到的相同结果:
print(jax.jit(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
这里是使用 vmap
得到的相同结果:
print(jax.vmap(j1)(z))
[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]
然而,如果我们调用 jax.grad
,我们会看到一个错误,因为这个函数没有定义自动微分规则:
jax.grad(j1)(z)
ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.
让我们为此定义一个自定义的梯度规则。查看Bessel第一类函数的定义,我们发现关于自变量z
的导数有一个相对简单的递归关系:
根据给定的递归关系,我们可以修正关于 Bessel 第一类函数 \(J_\nu(z)\) 的导数公式。这个公式给出了 \(J_\nu(z)\) 对于自变量 \(z\) 的导数。实际的导数公式应该如下:
关于\(\nu\)的梯度更为复杂,但由于我们将v
参数限制为整数类型,因此在这个例子中我们不需要担心它的梯度。
我们可以使用jax.custom_jvp
来为我们的回调函数定义这个自动微分规则:
jv = jax.custom_jvp(jv)
@jv.defjvp
def _jv_jvp(primals, tangents):
v, z = primals
_, z_dot = tangents # 注意:v_dot始终为0,因为v是整数。
jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
return jv(v, z), z_dot * djv_dz
现在计算我们函数的梯度将正常工作:
j1 = partial(jv, 1)
print(jax.grad(j1)(2.0))
-0.06447162
进一步来说,由于我们已经根据 jv
自身定义了我们的梯度,JAX 的架构意味着我们可以免费获得二阶及更高阶的导数:
jax.hessian(j1)(2.0)
Array(-0.4003078, dtype=float32, weak_type=True)
请注意,虽然这一切在JAX中工作正常,但每次调用我们的基于回调的jv
函数时,都将导致输入数据从设备传递到主机,并将scipy.special.jv
的输出从主机传回设备。
在像GPU或TPU这样的加速器上运行时,这种数据移动和主机同步可能导致每次调用jv
时产生显著的开销。
然而,如果您在单个CPU上运行JAX(“主机”和“设备”位于同一硬件上),JAX通常会以快速的零复制方式进行数据传输,这使得这种模式相对简单,能够扩展JAX的功能。