jax.experimental.host_callback.id_tap#
- jax.experimental.host_callback.id_tap(tap_func, arg, *, result=None, tap_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK, **kwargs)[源代码]#
主机回调水龙头原语,类似于带有对
tap_func
调用的标识函数。警告
自2024年3月20日起,host_callback API 已被弃用。该功能已被 新的 JAX 外部回调 所取代。详情请参见 google/jax#20385。
id_tap
在语义上表现得像恒等函数,但具有一个副作用,即用户定义的Python函数会以参数的运行时值被调用。- 参数:
tap_func – 调用
tap_func(arg, transforms)
形式的 tap 函数,其中arg
如下面所述,而transforms
是以(name, params)
形式应用的 JAX 变换序列。如果可选参数 tap_with_device 为 True,那么调用还包括作为关键字参数的值被 tap 的设备:tap_func(arg, transforms, device=dev)
。arg – 传递给 tap 函数的参数可以是 JAX 类型的 pytree。
result – 如果指定,则指定
id_tap
的返回值。此值不会传递给 tap 函数,实际上也不会从设备发送到主机。如果未指定result
参数,则id_tap
的返回值为arg
。tap_with_device – 如果为真,则调用tap函数,并将tap来源的设备作为关键字参数传递。
device_index – 指定在SPMD程序中从哪个设备调用tap函数。仅在使用outfeed实现机制时有效,即在CPU上无效,除非使用–jax_host_callback_outfeed=True。
callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调风格。参见 google/jax#20385。
- 返回:
arg
,如果给出则使用result
。
执行顺序是根据数据依赖性:在所有参数和
result
的值(如果存在)计算之后,以及在返回值被使用之前。id_tap
的至少一个返回值必须在计算的其余部分中使用,否则此操作无效。点击操作甚至适用于在加速器上执行的代码,甚至适用于在JAX变换下的代码。
更多详情请参阅
jax.experimental.host_callback
模块文档。