jax.extend.ffi.ffi_call

目录

jax.extend.ffi.ffi_call#

jax.extend.ffi.ffi_call(target_name, result_shape_dtypes, *args, vectorized=False, **kwargs)[源代码][源代码]#

调用外部函数接口 (FFI) 目标。

pure_callback() 类似,ffi_callvmap() 下的行为取决于 vectorized 的值。当 vectorizedTrue 时,假设 FFI 目标满足:ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])。换句话说,使用额外的领先维度调用 FFI 目标应返回与在循环内调用并沿零轴堆叠相同的结果。因此,FFI 目标将直接在批量输入上调用(其中批量轴是领先维度)。此外,回调应返回具有相应领先批量轴的输出。如果 vectorizedFalse``(默认行为),在 :func:`~jax.vmap` 下转换此 ``ffi_call 将导致在主体中带有 ffi_callscan()

参数:
  • target_name (str) – 使用 register_custom_call_target() 注册的 XLA FFI 自定义调用目标的名称。

  • result_shape_dtypes (ResultMetadata | Sequence[ResultMetadata]) – 一个对象,或对象序列,具有 shapedtype 属性,这些属性应与自定义调用输出或输出的形状和数据类型匹配。ShapeDtypeStruct 常用于定义 result_shape_dtypes 的元素。jax.core.abstract_token 可用于表示一个令牌类型的输出。

  • *args (ArrayLike) – 传递给自定义调用的参数。

  • vectorized (bool) – 布尔值,指定回调函数是否可以以向量化的方式操作,如上所述。

  • **kwargs (Any) – 作为使用XLA的FFI接口传递给自定义调用的命名属性的关键字参数。

返回:

一个或多个 Array 对象,其形状和数据类型与 result_shape_dtypes 匹配。

返回类型:

Array | list[Array]