jax.纯回调

jax.纯回调#

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)[源代码][源代码]#

调用一个纯Python回调。在 jit()/vmap()/等情况下工作。

更多解释,请参见 外部回调

pure_callback 允许在即时编译的 JAX 函数中调用 Python 函数。输入的 callback 将被传递放置在本地 CPU 上的 JAX 数组,并且它也应该返回 CPU 上的 JAX 数组。

回调被视为功能上纯的,这意味着它没有副作用,并且其输出值仅取决于其参数值。因此,它可以安全地被多次调用(例如,当通过 vmap()pmap() 转换时),或者在例如 jit 装饰的函数的输出对其值没有数据依赖时根本不被调用。如果数据依赖允许,纯回调也可能被重新排序。

vmap 被应用时,行为将取决于 vectorized 关键字参数的值。当 vectorizedTrue 时,假定回调函数遵守 jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])。因此,回调函数将直接在批量输入上调用(其中批量轴是前导维度)。此外,回调函数应返回具有相应前导批量轴的输出。如果未向量化,callback 将按顺序映射到批量轴上。例如,如果 callback = lambda x, y: np.matmul(x, y),那么我们可以自由地设置 vectorized=True,因为 np.matmul 函数处理任意前导批量维度。

参数:
  • callback (Callable[..., Any]) – 在主机上执行的函数。假定回调是一个纯函数(即没有副作用的函数):如果传递了一个非纯函数,它可能会以意想不到的方式运行,特别是在转换过程中。可调用对象将被传递数组的PyTrees作为参数,并且应该返回与``result_shape_dtypes``匹配的数组的PyTree。

  • result_shape_dtypes (Any) – pytree,其叶子具有 shapedtype 属性,其结构与回调函数在运行时的预期输出相匹配。jax.ShapeDtypeStruct 常用于定义叶子值。

  • *args (Any) – 传递给回调函数的参数

  • sharding (SingleDeviceSharding | None) – 可选的分片,指定应从哪个设备调用回调。

  • vectorized (bool) – 布尔值,指定回调函数是否可以以矢量化方式操作。

  • **kwargs (Any) – 传递给回调函数的键值参数

返回:

一个 jax.Array 对象的 pytree,其结构与 result_shape_dtypes 相匹配。

返回类型:

result

参见