jax.experimental.host_callback
模块#
从 JAX 加速器代码调用主机上 Python 函数的基本组件。
警告
自2024年3月20日起,host_callback API 已被弃用。该功能已被 新的 JAX 外部回调 所取代。详情请参见 google/jax#20385。
此模块介绍了主机回调函数 call()
、id_tap()
和 id_print()
,它们将参数从设备发送到主机,并在主机上调用用户定义的Python函数,可以选择将结果返回给设备计算。
我们下面展示这些函数如何使用。我们从 call()
开始,并讨论从 JAX 调用任意 Python 函数到 CPU 的示例,例如,使用 NumPy CPU 自定义内核。然后我们展示 id_tap()
和 id_print()
的使用,它们有一个限制,即不能从主机返回值到设备。这些原语通常更快,因为它们与设备代码异步执行。特别是,它们可以用于进入和调试 JAX 代码。
使用 call()
调用主机函数并将结果返回给设备#
使用 call()
在主机上调用计算,并将 NumPy 数组返回给设备计算。主机计算是有用的,例如,当设备计算需要一些需要在主机上进行 I/O 的数据,或者它需要一个在主机上可用但在 JAX 中不想编码的库时。例如,JAX 中一般矩阵的特征分解在 TPU 上不起作用。我们可以使用主机计算调用任何 JAX 加速器计算中的 Numpy 实现:
# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
return np.linalg.eigvals(m)
# This function is used in JAX
def device_fun(m):
# We send "m" to the host, asking it to call "host_eig" and return the result.
# We have to specify the result shape and dtype, either in the form of an
# example return value or any object that has `shape` and `dtype` attributes,
# e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
return hcb.call(host_eig, m,
# Given an input of shape (..., d, d), eig output has shape (..., d)
result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))
The call()
函数和 Python 主机函数都接受一个参数并返回一个结果,但这些可以是 pytrees。注意,我们必须使用 result_shape
关键字参数告诉 call()
从主机调用中期望的形状和数据类型。这一点很重要,因为设备代码是根据这些期望编译的。如果实际调用产生不同的结果形状,将在运行时引发错误。通常,此类错误以及主机计算引发的异常可能难以调试。请参阅下面的调试部分。这是 call()
的问题,但不是 id_tap()
的问题,因为后者设备代码不期望返回值。
可以在 jit 或 pmap 计算内部,或者在 cond/scan/while 控制流内部使用 call()
API。当在 jax.pmap()
内部使用时,每个参与的设备都会分别向主机发起调用:
def host_sin(x, *, device):
# The ``device`` argument is passed due to ``call_with_device=True`` below.
print(f"Invoking host_sin with {x.shape} on {device}")
return np.sin(x)
# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
result_shape=x,
# Ask that the `host_sin` function be passed `device=dev`
call_with_device=True))(
np.ones((2, 4), dtype=np.float32))
# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1
注意 call()
不支持任何 JAX 变换,但如我们下面所示,可以利用 JAX 中的自定义微分 的现有支持。
使用 id_tap()
在主机上调用 Python 函数,不返回任何值#
The id_tap()
和 id_print()
是 call()
的特殊情况,当你只想获得Python回调的副作用时。这些函数的优势在于,一旦参数被发送到主机,设备计算可以在不等待Python回调返回的情况下继续进行。对于 id_tap()
,你可以指定要调用的Python回调,而 id_print()
使用一个内置回调,该回调将参数打印到主机的 stdout 上。传递给 id_tap()
的Python函数接受两个位置参数(从设备计算中提取的值以及一个 transforms
元组,如下所述)。可选地,该函数还可以传递一个关键字参数 device
,其中包含从中提取值的设备。
一些例子:
def host_func(arg, transforms):
...do something with arg...
# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)
# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x)) # The argument can be a pytree
# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True) # Pass the device to the tap
# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)
# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))
上述示例都可以调整为使用 id_print()
,不同的是 id_print()
会在主机上打印位置参数,以及任何额外的 kwargs 和自动的 kwarg transforms
。
使用 barrier_wait()
等待直到所有回调执行完毕#
如果你的Python回调有副作用,你可能需要等到计算完成以确保副作用已被观察到。你可以为此目的使用 barrier_wait()
函数:
accumulator = []
def host_log(arg, transforms):
# We just record the arguments in a list
accumulator.append(arg)
def device_fun(x):
id_tap(host_log, x)
id_tap(host_log, 2. * x)
jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)
# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.
注意 barrier_wait()
将启动一个微小的计算,每个 jax.local_devices() 上都会有一次点击,并且会等待所有这些点击被接收。
使用 barrier_wait()
的替代方法是,如果所有回调都是 call()
,则只需等待计算结束:
accumulator = p[]
def host_log(arg):
# We just record the arguments in a list
accumulator.append(arg)
return 0. # return something
def device_fun(c):
y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
return y + z # return something that uses both results
res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()
并行化变换下的行为#
在使用 jax.pmap()
的情况下,代码将在多个设备上运行,每个设备将独立地获取其值。使用 id_print()
或 id_tap()
的 tap_with_device
选项可能会很有帮助,这样你可以看到哪个设备发送了哪些数据:
jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.])
# device=cpu:0 what=x,x^2: (3., 9.) # from the first device
# device=cpu:1 what=x,x^2: (4., 16.) # from the second device
在使用 jax.pmap()
进行多主机多设备操作时,每个主机将从其所有本地设备接收回调,每个回调对应于每个设备的切片。对于 call()
,回调必须仅返回与相应设备相关的结果切片。
在使用实验性的 pjit.pjit()
时,代码将在多个设备上运行,处理输入的不同分片。当前的主机回调实现将确保单个设备收集并输出整个操作数,通过单个回调。回调函数应返回整个数组,该数组随后将以单个输入形式发送回发出输出的同一设备。然后,该设备负责将所需的分片发送给其他设备:
with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
pjit.pjit(power3, in_shardings=(P("d"),),
out_shardings=(P("d"),))(np.array([3., 4.]))
# device=TPU:0 what=x,x^2: ( [3., 4.],
# [9., 16.] )
请注意,如果操作数在设备间分片,那么在其中一个设备上收集操作数可能会导致OOM(内存不足)。
在使用 pjit.pjit()
与多主机上的多个设备时,只有设备 0 对应的主机(相对于网格)会接收到回调,回调的参数是从所有主机上的所有参与设备收集的。对于 call()
,回调必须返回所有主机上所有设备的完整数组。
在 JAX 自动微分变换下的行为#
在使用 JAX 自动微分变换时,主机回调函数仅对原始值进行操作。考虑以下示例:
def power3(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
return y * x
power3(3.)
# what: x,x^2 : (3., 9.)
(你可以在 host_callback_test.HostCallbackTapTest.test_tap_transforms 中看到这些示例的测试。)
当在 jax.jvp()
下使用时,将会有一个仅包含原始值的回调:
jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
同样地,对于 jax.grad()
,我们只从前向计算中得到一个回调:
jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)
如果你想在执行 jax.jvp()
时调用切线的回调,你可以使用 custom_jvp。例如,你可以定义一个函数,除了它的 custom_jvp 会打印切线之外,它不做任何有趣的事情:
@jax.custom_jvp
def print_tangents(arg):
return None
@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
arg_dot, = tangents
hcb.id_print(arg_dot, what="tangents")
return primals, tangents
然后在你想要捕捉切线的位置使用这个函数:
def power3_with_tangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2")
print_tangents((x, y))
return y * x
jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)
你可以在 jax.grad()
期间对余切做类似的事情。这次你必须小心,在计算的其余部分使用你想要获取其余切值的那些值。因此,我们让 print_cotangents
返回其参数:
@jax.custom_vjp
def print_cotangents(arg):
# Must return the argument for which we want the cotangent.
return arg
# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
return ct_b,
print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)
def power3_with_cotangents(x):
y = x * x
# Print both 'x' and 'x^2'. Must pack as a tuple.
hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
(x1, y1) = print_cotangents((x, y))
# Must use the output of print_cotangents
return y1 * x1
jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)
如果你使用 ad_checkpoint.checkpoint()
来重新生成反向传播的残差,那么原始计算中的回调函数将被调用两次:
jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)
回调的顺序是:内部 power3
的原始计算,外部 power3
的原始计算,以及内部 power3
的残差重构。
jax.vmap 下的行为#
主机回调函数 id_print()
和 id_tap()
支持向量化转换 jax.vmap()
。
对于 jax.vmap()
,回调函数的参数是批处理的,并且回调函数会传递一个额外的特殊 transforms
,其中包含一个转换描述符列表,形式为 ("batch", {"batch_dims": ...})
,其中 ...
表示被点击值的批处理维度(每个参数一个条目,` None` 表示被广播的参数)。
jax.vmap(power3)(np.array([2., 3.])) # 转换: [(‘batch’, {‘batch_dims’: (0, 0)})] 什么: x,x^2 : ([2., 3.], [4., 9.])
参见 id_tap()
、id_print()
和 call()
的文档。
更多使用示例,请参见 tests/host_callback_test.py。
使用 call()
调用 TensorFlow 函数,支持反向模式自动微分#
主机计算的另一种可能用途是调用为另一个框架(如 TensorFlow)编写的库。在这种情况下,通过使用 jax.custom_vjp()
机制,将主机回调的 JAX 自动微分支持推迟到 TensorFlow 的自动微分机制,变得有趣。
一旦理解了 JAX 自定义 VJP 和 TensorFlow 自动微分机制,这相对容易实现。如何实现这一点的代码显示在 host_callback_to_tf_test.py 中的 call_tf_full_ad
函数中。此示例还支持任意高阶微分。
请注意,如果你只是想从 JAX 调用 TensorFlow 函数,你也可以使用 jax2tf.call_tf 函数。
使用 call()
在另一个设备上调用 JAX 函数,支持反向模式自动微分#
我们能够使用主机计算来在另一个设备上调用JAX计算,这应该不足为奇。参数从加速器发送到主机,然后发送到JAX主机计算将运行的外部设备,然后结果被发送回原始加速器。
如何实现这一点的代码展示在 host_callback_test.py 中的 call_jax_other_device 函数
。
底层细节与调试#
主机回调函数将按照在设备上执行发送操作的顺序,为每个设备执行。
多个设备的宿主回调函数可能会交错执行。设备的数据由JAX运行时管理的独立线程接收(每个设备一个线程)。运行时维护一个可配置大小的缓冲区(参见标志 --jax_host_callback_max_queue_byte_size
)。当缓冲区满时,所有接收线程都会暂停,这最终会暂停设备上的计算。运行时为每个设备额外提供一个线程,用于使用接收到的数据调用Python用户函数。如果回调处理速度慢,可能会导致运行时缓冲区填满,最终在设备需要发送数据时暂停计算。有关外发接收器运行时机制的更多详细信息,请参见 运行时代码。
为了暂停执行,直到设备上已经开始计算的所有数据到达并被处理完毕,请使用 barrier_wait()
。
用户定义的回调函数抛出的异常会连同它们的堆栈跟踪一起被记录,但接收线程不会停止。相反,最后一个异常会被记录,如果任何一个tap函数中发生了异常,随后的 barrier_wait()
将会抛出 CallbackException
。此异常将包含最后遇到的异常的文本和堆栈跟踪。
对于必须将结果返回给调用源设备的回调函数,例如 call()
,会出现进一步的复杂性。与 TPU 设备相比,这在 CPU/GPU 设备上的处理方式有所不同。
在CPU/GPU设备上,为了避免设备计算因等待永远不会到达的结果而卡住,如果在回调处理过程中发生任何错误(无论是由用户代码本身引发的还是由于返回值与预期的return_shape不匹配),我们会向设备发送一个形状为``int8[12345]``的“假”结果。这将导致设备计算中止,因为接收到的数据与预期的不符。在CPU上,运行时将以一个独特的错误消息崩溃:
` 检查失败:buffer->length() == buffer_length (12345 vs. ...) `
在GPU上,失败更加用户友好,并且会以以下方式呈现给Python程序:
` RET_CHECK 失败 ... 输入源缓冲区形状 s8[12345] ... 不匹配 `
要调试这些消息的根本原因,请参阅调试部分。
在TPU设备上,目前对输入没有形状检查,因此我们在出现错误时采取了更安全的做法,不发送这个假结果。这意味着计算将会挂起,并且不会抛出异常(但回调函数中的任何异常仍会出现在日志中)。
当前的实现使用了XLA提供的出料机制。从某种意义上说,该机制本身相当原始,因为接收器必须确切地知道每个传入数据包的形状,以及预期有多少数据包。这使得在同一计算中使用多种类型的数据变得困难,并且在条件语句或非固定迭代次数的循环中实际上无法使用。此外,直接使用出料机制的代码无法被JAX转换。所有这些限制都通过主机回调函数得到了解决。这里引入的敲击API使得轻松地为多种用途共享出料机制成为可能,同时支持所有转换。
请注意,在使用主机回调函数后,您不能直接使用 lax.outfeed。如果您之后需要使用 lax.outfeed,您可能需要 停止输出接收器()
。
由于实际调用回调函数是从 C++ 接收器进行的,因此可能难以调试这些调用。特别是,堆栈跟踪将不包括调用代码。您可以使用标志 jax_host_callback_inline``(或环境变量 ``JAX_HOST_CALLBACK_INLINE
)来确保对回调的调用是内联的。这仅在调用位于暂存上下文(jit()
或控制流原语)之外时有效。
C++ 接收器 在第一次调用 id_tap()
时会自动启动。为了正确停止它,启动时会注册一个 atexit
处理程序,以调用 barrier_wait()
并使用日志名称“at_exit”。
有一些环境变量可以用来为 C++ outfeed 接收器后端 启用日志记录。
TF_CPP_MIN_LOG_LEVEL=0
: 将开启INFO日志记录,以下所有内容都需要此设置。
TF_CPP_MIN_VLOG_LEVEL=3
: 将使所有 VLOG 日志记录到级别 3 的行为像 INFO 日志一样。这可能太多了,但你会看到哪些模块正在记录相关信息,然后你可以选择从哪些模块进行日志记录。
TF_CPP_VMODULE=<模块名>=3
(模块名可以是C++或Python,不带扩展名)。
你还应该使用 --verbosity=2
标志,以便查看来自 Python 的日志。
例如,你可以尝试在 host_callback
模块中启用日志记录:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
如果你想在较低级别的实现模块中启用日志记录,请尝试:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
(对于bazel测试,使用 –test_arg=–vmodule=…)
- 待办事项:
更多性能测试。
探索使用外部编译实现TPU。
探索使用 XLA CustomCall 进行 CPU 和 GPU 的实现。
API#
|
主机回调水龙头原语,类似于带有对 |
|
类似于带有打印功能的 |
|
向主机发出调用,并期待一个结果。 |
|
阻塞调用线程,直到所有当前的外流数据处理完毕。 |
信号表示某些回调函数发生了异常。 |