jax.typing
模块#
JAX 类型模块是 JAX 特定静态类型注释的所在地。这个子模块正在开发中;要查看此处导出的类型的提案,请参见 https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html。
当前可用的类型有:
jax.Array
: 用于任何 JAX 数组或 tracer 的注解(即在 JAX 变换中的数组表示)。jax.typing.ArrayLike
: annotation for any value that is safe to implicitly cast to a JAX array; this includesjax.Array
,numpy.ndarray
, as well as Python builtin numeric values (e.g.int
,float
, etc.) and numpy scalar values (e.g.numpy.int32
,numpy.flota64
, etc.)jax.typing.DTypeLike
: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. ‘float32’, ‘int32’), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype(‘float32’)), or objects with a dtype attribute (e.g. jnp.float32, jnp.int32).
我们可能在未来的版本中添加额外的类型。
JAX 类型注解最佳实践#
在公共API函数中注释JAX数组时,我们建议对数组输入使用 ArrayLike
,对数组输出使用 Array
。
例如,你的函数可能看起来像这样:
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
def my_function(x: ArrayLike) -> Array:
# Runtime type validation, Python 3.10 or newer:
if not isinstance(x, ArrayLike):
raise TypeError(f"Expected arraylike input; got {x}")
# Runtime type validation, any Python version:
if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
raise TypeError(f"Expected arraylike input; got {x}")
# Convert input to jax.Array:
x_arr = jnp.asarray(x)
# ... do some computation; JAX functions will return Array types:
result = x_arr.sum(0) / x_arr.shape[0]
# return an Array
return result
JAX 的大多数公共 API 都遵循这种模式。特别要注意的是,我们建议 JAX 函数不接受诸如 list
或 tuple
这样的序列来代替数组,因为这可能会在 JAX 变换(如 jit()
)中造成额外的开销,并且在批量变换(如 vmap()
或 jax.pmap()
)中可能会出现意外行为。有关更多信息,请参阅 非数组输入 NumPy 与 JAX