jax.experimental.host_callback 模块#

从 JAX 加速器代码调用主机上 Python 函数的基本组件。

警告

自2024年3月20日起,host_callback API 已被弃用。该功能已被 新的 JAX 外部回调 所取代。详情请参见 google/jax#20385

此模块介绍了主机回调函数 call()id_tap()id_print(),它们将参数从设备发送到主机,并在主机上调用用户定义的Python函数,可以选择将结果返回给设备计算。

我们下面展示这些函数如何使用。我们从 call() 开始,并讨论从 JAX 调用任意 Python 函数到 CPU 的示例,例如,使用 NumPy CPU 自定义内核。然后我们展示 id_tap()id_print() 的使用,它们有一个限制,即不能从主机返回值到设备。这些原语通常更快,因为它们与设备代码异步执行。特别是,它们可以用于进入和调试 JAX 代码。

使用 call() 调用主机函数并将结果返回给设备#

使用 call() 在主机上调用计算,并将 NumPy 数组返回给设备计算。主机计算是有用的,例如,当设备计算需要一些需要在主机上进行 I/O 的数据,或者它需要一个在主机上可用但在 JAX 中不想编码的库时。例如,JAX 中一般矩阵的特征分解在 TPU 上不起作用。我们可以使用主机计算调用任何 JAX 加速器计算中的 Numpy 实现:

# This function runs on the host
def host_eig(m: np.ndarray) -> np.ndarray:
  return np.linalg.eigvals(m)

# This function is used in JAX
def device_fun(m):
  # We send "m" to the host, asking it to call "host_eig" and return the result.
  # We have to specify the result shape and dtype, either in the form of an
  # example return value or any object that has `shape` and `dtype` attributes,
  # e.g., a NumPy array or a `jax.ShapeDtypeStruct`.
  return hcb.call(host_eig, m,
                  # Given an input of shape (..., d, d), eig output has shape (..., d)
                  result_shape=jax.ShapeDtypeStruct(m.shape[:-1], m.dtype))

The call() 函数和 Python 主机函数都接受一个参数并返回一个结果,但这些可以是 pytrees。注意,我们必须使用 result_shape 关键字参数告诉 call() 从主机调用中期望的形状和数据类型。这一点很重要,因为设备代码是根据这些期望编译的。如果实际调用产生不同的结果形状,将在运行时引发错误。通常,此类错误以及主机计算引发的异常可能难以调试。请参阅下面的调试部分。这是 call() 的问题,但不是 id_tap() 的问题,因为后者设备代码不期望返回值。

可以在 jit 或 pmap 计算内部,或者在 cond/scan/while 控制流内部使用 call() API。当在 jax.pmap() 内部使用时,每个参与的设备都会分别向主机发起调用:

def host_sin(x, *, device):
  # The ``device`` argument is passed due to ``call_with_device=True`` below.
  print(f"Invoking host_sin with {x.shape} on {device}")
  return np.sin(x)

# Use pmap to run the computation on two devices
jax.pmap(lambda x: hcb.call(host_sin, x,
                            result_shape=x,
                            # Ask that the `host_sin` function be passed `device=dev`
                            call_with_device=True))(
         np.ones((2, 4), dtype=np.float32))

# prints (in arbitrary order)
# Invoking host_sin with (4,) on cpu:0
# Invoking host_sin with (4,) on cpu:1

注意 call() 不支持任何 JAX 变换,但如我们下面所示,可以利用 JAX 中的自定义微分 的现有支持。

使用 id_tap() 在主机上调用 Python 函数,不返回任何值#

The id_tap()id_print()call() 的特殊情况,当你只想获得Python回调的副作用时。这些函数的优势在于,一旦参数被发送到主机,设备计算可以在不等待Python回调返回的情况下继续进行。对于 id_tap(),你可以指定要调用的Python回调,而 id_print() 使用一个内置回调,该回调将参数打印到主机的 stdout 上。传递给 id_tap() 的Python函数接受两个位置参数(从设备计算中提取的值以及一个 transforms 元组,如下所述)。可选地,该函数还可以传递一个关键字参数 device,其中包含从中提取值的设备。

一些例子:

def host_func(arg, transforms):
   ...do something with arg...

# calls host_func(2x, []) on host
id_tap(host_func, 2 * x)

# calls host_func((2x, 3x), [])
id_tap(host_func, (2 * x, 3 * x))  # The argument can be a pytree

# calls host_func(2x, [], device=jax.devices()[0])
id_tap(host_func, 2 * x, tap_with_device=True)  # Pass the device to the tap

# calls host_func(2x, [], what='activation')
id_tap(functools.partial(host_func, what='activation'), 2 * x)

# calls host_func(dict(x=x, y=y), what='data')
id_tap(lambda tap, transforms: host_func(tap, what='data'), dict(x=x, y=y))

上述示例都可以调整为使用 id_print() ,不同的是 id_print() 会在主机上打印位置参数,以及任何额外的 kwargs 和自动的 kwarg transforms

使用 barrier_wait() 等待直到所有回调执行完毕#

如果你的Python回调有副作用,你可能需要等到计算完成以确保副作用已被观察到。你可以为此目的使用 barrier_wait() 函数:

accumulator = []
def host_log(arg, transforms):
  # We just record the arguments in a list
  accumulator.append(arg)


def device_fun(x):
  id_tap(host_log, x)
  id_tap(host_log, 2. * x)

jax.jit(device_fun)(1.)
jax.jit(device_fun)(1.)

# At this point, we have started two computations, each with two
# taps, but they may not have yet executed.
barrier_wait()
# Now we know that all the computations started before `barrier_wait`
# on all devices, have finished, and all the callbacks have finished
# executing.

注意 barrier_wait() 将启动一个微小的计算,每个 jax.local_devices() 上都会有一次点击,并且会等待所有这些点击被接收。

使用 barrier_wait() 的替代方法是,如果所有回调都是 call(),则只需等待计算结束:

accumulator = p[]
def host_log(arg):
  # We just record the arguments in a list
  accumulator.append(arg)
  return 0.  #  return something


def device_fun(c):
  y = call(host_log, x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  z = call(host_log, 2. * x, result_shape=jax.ShapeDtypeStruct((), np.float32))
  return y + z  # return something that uses both results

res1 = jax.jit(device_fun)(1.)
res2 = jax.jit(device_fun)(1.)
res1.block_until_ready()
res2.block_until_ready()

并行化变换下的行为#

在使用 jax.pmap() 的情况下,代码将在多个设备上运行,每个设备将独立地获取其值。使用 id_print()id_tap()tap_with_device 选项可能会很有帮助,这样你可以看到哪个设备发送了哪些数据:

jax.pmap(power3, devices=jax.local_devices()[:2])(np.array([3., 4.])
# device=cpu:0 what=x,x^2: (3., 9.)  # from the first device
# device=cpu:1 what=x,x^2: (4., 16.)  # from the second device

在使用 jax.pmap() 进行多主机多设备操作时,每个主机将从其所有本地设备接收回调,每个回调对应于每个设备的切片。对于 call() ,回调必须仅返回与相应设备相关的结果切片。

在使用实验性的 pjit.pjit() 时,代码将在多个设备上运行,处理输入的不同分片。当前的主机回调实现将确保单个设备收集并输出整个操作数,通过单个回调。回调函数应返回整个数组,该数组随后将以单个输入形式发送回发出输出的同一设备。然后,该设备负责将所需的分片发送给其他设备:

with jax.sharding.Mesh(jax.local_devices()[:2], ["d"]):
  pjit.pjit(power3, in_shardings=(P("d"),),
            out_shardings=(P("d"),))(np.array([3., 4.]))

# device=TPU:0 what=x,x^2: ( [3., 4.],
#                            [9., 16.] )

请注意,如果操作数在设备间分片,那么在其中一个设备上收集操作数可能会导致OOM(内存不足)。

在使用 pjit.pjit() 与多主机上的多个设备时,只有设备 0 对应的主机(相对于网格)会接收到回调,回调的参数是从所有主机上的所有参与设备收集的。对于 call(),回调必须返回所有主机上所有设备的完整数组。

在 JAX 自动微分变换下的行为#

在使用 JAX 自动微分变换时,主机回调函数仅对原始值进行操作。考虑以下示例:

def power3(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  return y * x

power3(3.)
# what: x,x^2 : (3., 9.)

(你可以在 host_callback_test.HostCallbackTapTest.test_tap_transforms 中看到这些示例的测试。)

当在 jax.jvp() 下使用时,将会有一个仅包含原始值的回调:

jax.jvp(power3, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)

同样地,对于 jax.grad(),我们只从前向计算中得到一个回调:

jax.grad(power3)(3.)
# what: x,x^2 : (3., 9.)

如果你想在执行 jax.jvp() 时调用切线的回调,你可以使用 custom_jvp。例如,你可以定义一个函数,除了它的 custom_jvp 会打印切线之外,它不做任何有趣的事情:

@jax.custom_jvp
def print_tangents(arg):
  return None

@print_tangents.defjvp
def print_tangents_jvp(primals, tangents):
  arg_dot, = tangents
  hcb.id_print(arg_dot, what="tangents")
  return primals, tangents

然后在你想要捕捉切线的位置使用这个函数:

def power3_with_tangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2")
  print_tangents((x, y))
  return y * x

jax.jvp(power3_with_tangents, (3.,), (0.1,))
# what: x,x^2 : (3., 9.)
# what: tangents : (0.1, 0.6)

你可以在 jax.grad() 期间对余切做类似的事情。这次你必须小心,在计算的其余部分使用你想要获取其余切值的那些值。因此,我们让 print_cotangents 返回其参数:

@jax.custom_vjp
def print_cotangents(arg):
  # Must return the argument for which we want the cotangent.
  return arg

# f_fwd: a -> (b, residual)
def print_cotangents_fwd(arg):
  return print_cotangents(arg), None
# f_bwd: (residual, CT b) -> [CT a]
def print_cotangents_bwd(residual, ct_b):
  hcb.id_print(ct_b, what="cotangents", output_stream=testing_stream)
  return ct_b,

print_cotangents.defvjp(print_cotangents_fwd, print_cotangents_bwd)

def power3_with_cotangents(x):
  y = x * x
  # Print both 'x' and 'x^2'. Must pack as a tuple.
  hcb.id_print((x, y), what="x,x^2", output_stream=testing_stream)
  (x1, y1) = print_cotangents((x, y))
  # Must use the output of print_cotangents
  return y1 * x1

jax.grad(power3_with_cotangents)(3.)
# what: x,x^2 : (3., 9.)
# what: cotangents : (9., 3.)

如果你使用 ad_checkpoint.checkpoint() 来重新生成反向传播的残差,那么原始计算中的回调函数将被调用两次:

jax.grad(lambda x: power3(ad_checkpoint.checkpoint(power3)(x)))(3.)
# what: x,x^2 : (3., 9.)
# what: x,x^2 : (27., 729.)
# what: x,x^2 : (3., 9.)

回调的顺序是:内部 power3 的原始计算,外部 power3 的原始计算,以及内部 power3 的残差重构。

jax.vmap 下的行为#

主机回调函数 id_print()id_tap() 支持向量化转换 jax.vmap()

对于 jax.vmap() ,回调函数的参数是批处理的,并且回调函数会传递一个额外的特殊 transforms ,其中包含一个转换描述符列表,形式为 ("batch", {"batch_dims": ...}) ,其中 ... 表示被点击值的批处理维度(每个参数一个条目,` None` 表示被广播的参数)。

jax.vmap(power3)(np.array([2., 3.])) # 转换: [(‘batch’, {‘batch_dims’: (0, 0)})] 什么: x,x^2 : ([2., 3.], [4., 9.])

参见 id_tap()id_print()call() 的文档。

更多使用示例,请参见 tests/host_callback_test.py。

使用 call() 调用 TensorFlow 函数,支持反向模式自动微分#

主机计算的另一种可能用途是调用为另一个框架(如 TensorFlow)编写的库。在这种情况下,通过使用 jax.custom_vjp() 机制,将主机回调的 JAX 自动微分支持推迟到 TensorFlow 的自动微分机制,变得有趣。

一旦理解了 JAX 自定义 VJP 和 TensorFlow 自动微分机制,这相对容易实现。如何实现这一点的代码显示在 host_callback_to_tf_test.py 中的 call_tf_full_ad 函数中。此示例还支持任意高阶微分。

请注意,如果你只是想从 JAX 调用 TensorFlow 函数,你也可以使用 jax2tf.call_tf 函数

使用 call() 在另一个设备上调用 JAX 函数,支持反向模式自动微分#

我们能够使用主机计算来在另一个设备上调用JAX计算,这应该不足为奇。参数从加速器发送到主机,然后发送到JAX主机计算将运行的外部设备,然后结果被发送回原始加速器。

如何实现这一点的代码展示在 host_callback_test.py 中的 call_jax_other_device 函数

底层细节与调试#

主机回调函数将按照在设备上执行发送操作的顺序,为每个设备执行。

多个设备的宿主回调函数可能会交错执行。设备的数据由JAX运行时管理的独立线程接收(每个设备一个线程)。运行时维护一个可配置大小的缓冲区(参见标志 --jax_host_callback_max_queue_byte_size)。当缓冲区满时,所有接收线程都会暂停,这最终会暂停设备上的计算。运行时为每个设备额外提供一个线程,用于使用接收到的数据调用Python用户函数。如果回调处理速度慢,可能会导致运行时缓冲区填满,最终在设备需要发送数据时暂停计算。有关外发接收器运行时机制的更多详细信息,请参见 运行时代码

为了暂停执行,直到设备上已经开始计算的所有数据到达并被处理完毕,请使用 barrier_wait()

用户定义的回调函数抛出的异常会连同它们的堆栈跟踪一起被记录,但接收线程不会停止。相反,最后一个异常会被记录,如果任何一个tap函数中发生了异常,随后的 barrier_wait() 将会抛出 CallbackException。此异常将包含最后遇到的异常的文本和堆栈跟踪。

对于必须将结果返回给调用源设备的回调函数,例如 call() ,会出现进一步的复杂性。与 TPU 设备相比,这在 CPU/GPU 设备上的处理方式有所不同。

在CPU/GPU设备上,为了避免设备计算因等待永远不会到达的结果而卡住,如果在回调处理过程中发生任何错误(无论是由用户代码本身引发的还是由于返回值与预期的return_shape不匹配),我们会向设备发送一个形状为``int8[12345]``的“假”结果。这将导致设备计算中止,因为接收到的数据与预期的不符。在CPU上,运行时将以一个独特的错误消息崩溃:

` 检查失败:buffer->length() == buffer_length (12345 vs. ...) `

在GPU上,失败更加用户友好,并且会以以下方式呈现给Python程序:

` RET_CHECK 失败 ... 输入源缓冲区形状 s8[12345] ... 不匹配 `

要调试这些消息的根本原因,请参阅调试部分。

在TPU设备上,目前对输入没有形状检查,因此我们在出现错误时采取了更安全的做法,不发送这个假结果。这意味着计算将会挂起,并且不会抛出异常(但回调函数中的任何异常仍会出现在日志中)。

当前的实现使用了XLA提供的出料机制。从某种意义上说,该机制本身相当原始,因为接收器必须确切地知道每个传入数据包的形状,以及预期有多少数据包。这使得在同一计算中使用多种类型的数据变得困难,并且在条件语句或非固定迭代次数的循环中实际上无法使用。此外,直接使用出料机制的代码无法被JAX转换。所有这些限制都通过主机回调函数得到了解决。这里引入的敲击API使得轻松地为多种用途共享出料机制成为可能,同时支持所有转换。

请注意,在使用主机回调函数后,您不能直接使用 lax.outfeed。如果您之后需要使用 lax.outfeed,您可能需要 停止输出接收器()

由于实际调用回调函数是从 C++ 接收器进行的,因此可能难以调试这些调用。特别是,堆栈跟踪将不包括调用代码。您可以使用标志 jax_host_callback_inline``(或环境变量 ``JAX_HOST_CALLBACK_INLINE)来确保对回调的调用是内联的。这仅在调用位于暂存上下文(jit() 或控制流原语)之外时有效。

C++ 接收器 在第一次调用 id_tap() 时会自动启动。为了正确停止它,启动时会注册一个 atexit 处理程序,以调用 barrier_wait() 并使用日志名称“at_exit”。

有一些环境变量可以用来为 C++ outfeed 接收器后端 启用日志记录。

  • TF_CPP_MIN_LOG_LEVEL=0: 将开启INFO日志记录,以下所有内容都需要此设置。

  • TF_CPP_MIN_VLOG_LEVEL=3: 将使所有 VLOG 日志记录到级别 3 的行为像 INFO 日志一样。这可能太多了,但你会看到哪些模块正在记录相关信息,然后你可以选择从哪些模块进行日志记录。

  • TF_CPP_VMODULE=<模块名>=3 (模块名可以是C++或Python,不带扩展名)。

你还应该使用 --verbosity=2 标志,以便查看来自 Python 的日志。

例如,你可以尝试在 host_callback 模块中启用日志记录:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

如果你想在较低级别的实现模块中启用日志记录,请尝试:TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple

(对于bazel测试,使用 –test_arg=–vmodule=…)

待办事项:
  • 更多性能测试。

  • 探索使用外部编译实现TPU。

  • 探索使用 XLA CustomCall 进行 CPU 和 GPU 的实现。

API#

id_tap(tap_func, arg, *[, result, ...])

主机回调水龙头原语,类似于带有对 tap_func 调用的标识函数。

id_print(arg, *[, result, tap_with_device, ...])

类似于带有打印功能的 id_tap()

call(callback_func, arg, *[, result_shape, ...])

向主机发出调用,并期待一个结果。

barrier_wait([logging_name])

阻塞调用线程,直到所有当前的外流数据处理完毕。

CallbackException

信号表示某些回调函数发生了异常。