jax.experimental.io_回调

目录

jax.experimental.io_回调#

jax.experimental.io_callback(callback, result_shape_dtypes, *args, sharding=None, ordered=False, **kwargs)[源代码][源代码]#

调用一个不纯的Python回调函数。

更多解释,请参见 外部回调

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。假定它是一个不纯的函数。如果 callback 是纯函数,使用 jax.pure_callback() 可能会导致更高效的执行。

  • result_shape_dtypes (Any) – pytree,其叶子具有 shapedtype 属性,其结构与回调函数在运行时的预期输出相匹配。jax.ShapeDtypeStruct 常用于定义叶子值。

  • *args (Any) – 传递给回调函数的参数

  • sharding (SingleDeviceSharding | None) – 可选的分片,指定应从哪个设备调用回调。

  • ordered (bool) – 布尔值,指定对回调的连续调用是否必须按顺序进行。

  • **kwargs (Any) – 传递给回调函数的键值参数

返回:

一个 jax.Array 对象的 pytree,其结构与 result_shape_dtypes 相匹配。

返回类型:

result

参见