jax.experimental.host_callback.id_tap

目录

jax.experimental.host_callback.id_tap#

jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs)[源代码]#

主机回调水龙头原语,类似于带有对 tap_func 调用的标识函数。

警告

自2024年3月20日起,host_callback API 已被弃用。该功能已被 新的 JAX 外部回调 所取代。详情请参见 google/jax#20385

id_tap 在语义上表现得像恒等函数,但具有一个副作用,即用户定义的Python函数会以参数的运行时值被调用。

参数:
  • tap_func – 调用 tap_func(arg, transforms) 形式的 tap 函数,其中 arg 如下面所述,而 transforms 是以 (name, params) 形式应用的 JAX 变换序列。如果可选参数 tap_with_device 为 True,那么调用还包括作为关键字参数的值被 tap 的设备:tap_func(arg, transforms, device=dev)

  • arg – 传递给 tap 函数的参数可以是 JAX 类型的 pytree。

  • result – 如果指定,则指定 id_tap 的返回值。此值不会传递给 tap 函数,实际上也不会从设备发送到主机。如果未指定 result 参数,则 id_tap 的返回值为 arg

  • tap_with_device – 如果为真,则调用tap函数,并将tap来源的设备作为关键字参数传递。

  • device_index – 指定在SPMD程序中从哪个设备调用tap函数。仅在使用outfeed实现机制时有效,即在CPU上无效,除非使用–jax_host_callback_outfeed=True。

  • callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调风格。参见 google/jax#20385

返回:

arg,如果给出则使用 result

执行顺序是根据数据依赖性:在所有参数和 result 的值(如果存在)计算之后,以及在返回值被使用之前。id_tap 的至少一个返回值必须在计算的其余部分中使用,否则此操作无效。

点击操作甚至适用于在加速器上执行的代码,甚至适用于在JAX变换下的代码。

更多详情请参阅 jax.experimental.host_callback 模块文档。