jax.experimental.host_callback.call

目录

jax.experimental.host_callback.call#

jax.experimental.host_callback.call(callback_func, arg, *, result_shape=None, call_with_device=False, device_index=0, callback_flavor=CallbackFlavor.IO_CALLBACK)[源代码]#

向主机发出调用,并期待一个结果。

警告

自2024年3月20日起,host_callback API 已被弃用。该功能已被 新的 JAX 外部回调 所取代。详情请参见 google/jax#20385

参数:
  • callback_func (Callable) – 在主机上调用的 Python 函数为 callback_func(arg)。如果 call_with_device 可选参数为 True,那么调用还包括来自调用源的 device kwarg:callback_func(arg, device=dev)。此函数必须返回一个 numpy ndarrays 的 pytree。

  • arg – 传递给回调函数的参数可以是 JAX 类型的 pytree。

  • result_shape – 描述预期结果形状和数据类型的值。这可以是一个数值标量,从中获取形状和数据类型,或者是一个具有 .shape.dtype 属性的对象。如果回调的结果是一个 pytree,那么 result_shape 也应该是一个具有相同结构的 pytree。特别是,如果函数没有任何结果,result_shape 可以是 ()None。包含 call 的设备代码是根据预期的结果形状和数据类型编译的,如果在运行时 callback_func 调用返回了不同类型的结果,将会引发错误。

  • call_with_device – 如果为真,则回调函数会以调用来源的设备作为关键字参数被调用。

  • device_index – 指定在SPMD程序中从哪个设备调用tap函数。仅在使用outfeed实现机制时有效,即在CPU上无效,除非使用–jax_host_callback_outfeed=True。

  • callback_flavor – 如果使用 JAX_HOST_CALLBACK_LEGACY=False 运行,则指定要使用的回调风格。参见 google/jax#20385

返回:

callback_func 调用的结果。

更多详情请参阅 jax.experimental.host_callback 模块文档。