jax.extend.ffi.ffi_call#
- jax.extend.ffi.ffi_call(target_name, result_shape_dtypes, *args, vectorized=False, **kwargs)[源代码][源代码]#
调用外部函数接口 (FFI) 目标。
与
pure_callback()
类似,ffi_call
在vmap()
下的行为取决于vectorized
的值。当vectorized
为True
时,假设 FFI 目标满足:ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
。换句话说,使用额外的领先维度调用 FFI 目标应返回与在循环内调用并沿零轴堆叠相同的结果。因此,FFI 目标将直接在批量输入上调用(其中批量轴是领先维度)。此外,回调应返回具有相应领先批量轴的输出。如果vectorized
为False``(默认行为),在 :func:`~jax.vmap` 下转换此 ``ffi_call
将导致在主体中带有ffi_call
的scan()
。- 参数:
target_name (str) – 使用
register_custom_call_target()
注册的 XLA FFI 自定义调用目标的名称。result_shape_dtypes (ResultMetadata | Sequence[ResultMetadata]) – 一个对象,或对象序列,具有
shape
和dtype
属性,这些属性应与自定义调用输出或输出的形状和数据类型匹配。ShapeDtypeStruct
常用于定义result_shape_dtypes
的元素。jax.core.abstract_token
可用于表示一个令牌类型的输出。*args (ArrayLike) – 传递给自定义调用的参数。
vectorized (bool) – 布尔值,指定回调函数是否可以以向量化的方式操作,如上所述。
**kwargs (Any) – 作为使用XLA的FFI接口传递给自定义调用的命名属性的关键字参数。
- 返回:
一个或多个
Array
对象,其形状和数据类型与result_shape_dtypes
匹配。- 返回类型: