jax.extend.ffi.ffi_call#
- jax.extend.ffi.ffi_call(target_name, result_shape_dtypes, *args, vectorized=False, **kwargs)[源代码][源代码]#
调用外部函数接口 (FFI) 目标。
与
pure_callback()类似,ffi_call在vmap()下的行为取决于vectorized的值。当vectorized为True时,假设 FFI 目标满足:ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])。换句话说,使用额外的领先维度调用 FFI 目标应返回与在循环内调用并沿零轴堆叠相同的结果。因此,FFI 目标将直接在批量输入上调用(其中批量轴是领先维度)。此外,回调应返回具有相应领先批量轴的输出。如果vectorized为False``(默认行为),在 :func:`~jax.vmap` 下转换此 ``ffi_call将导致在主体中带有ffi_call的scan()。- 参数:
target_name (str) – 使用
register_custom_call_target()注册的 XLA FFI 自定义调用目标的名称。result_shape_dtypes (ResultMetadata | Sequence[ResultMetadata]) – 一个对象,或对象序列,具有
shape和dtype属性,这些属性应与自定义调用输出或输出的形状和数据类型匹配。ShapeDtypeStruct常用于定义result_shape_dtypes的元素。jax.core.abstract_token可用于表示一个令牌类型的输出。*args (ArrayLike) – 传递给自定义调用的参数。
vectorized (bool) – 布尔值,指定回调函数是否可以以向量化的方式操作,如上所述。
**kwargs (Any) – 作为使用XLA的FFI接口传递给自定义调用的命名属性的关键字参数。
- 返回:
一个或多个
Array对象,其形状和数据类型与result_shape_dtypes匹配。- 返回类型: