jax.jit#
- jax.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[源代码][源代码]#
为
fun
设置即时编译(JIT)与 XLA。- 参数:
fun (Callable) – 要进行即时编译的函数。
fun
应该是一个纯函数。fun
的参数和返回值应该是数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)。由static_argnums
指示的位置参数可以是任何可哈希类型。静态参数作为编译缓存键的一部分包含在内,这就是为什么必须定义哈希和相等运算符的原因。JAX 保留对fun
的弱引用,作为编译缓存键使用,因此fun
对象必须是可弱引用的。in_shardings – 可选的,一个
Sharding
或包含Sharding
叶子和结构的 pytree,该结构是传递给fun
的位置参数元组的树前缀。如果提供,传递给fun
的位置参数的分片必须与in_shardings
兼容,否则会引发错误,并且编译的计算具有与in_shardings
相对应的输入分片。如果未提供,编译的计算的输入分片将从参数分片中推断。out_shardings – 可选的,一个
Sharding
或带有Sharding
叶子和结构的 pytree,该结构是fun
输出的树前缀。如果提供,它与将相应的jax.lax.with_sharding_constraint()
应用于fun
的输出具有相同的效果。static_argnums (int | Sequence[int] | None) – 可选的,一个整数或整数集合,指定哪些位置参数作为静态参数(跟踪和编译时常量)。静态参数应该是可哈希的,这意味着实现了
__hash__
和__eq__
,并且是不可变的。否则它们可以是任意的 Python 对象。使用这些常量的不同值调用 jitted 函数将触发重新编译。不是类数组或其容器的参数必须标记为静态。如果没有提供static_argnums
或static_argnames
,则没有参数被视为静态。如果未提供static_argnums
但提供了static_argnames
,或者反之亦然,JAX 使用inspect.signature(fun)
查找与static_argnames
对应的位置参数(或反之亦然)。如果同时提供了static_argnums
和static_argnames
,则不使用inspect.signature
,并且只有列在static_argnums
或static_argnames
中的实际参数将被视为静态。static_argnames (str | Iterable[str] | None) – 可选,一个字符串或字符串集合,指定哪些命名参数作为静态(编译时常量)处理。详情请参阅
static_argnums
的注释。如果未提供但static_argnums
已设置,默认值基于调用inspect.signature(fun)
来查找相应的命名参数。donate_argnums (int | Sequence[int] | None) – 可选的整数集合,用于指定哪些位置参数缓冲区可以被计算覆盖并在调用者中标记为删除。如果你在计算开始后不再需要这些缓冲区,那么捐赠它们是安全的。在某些情况下,XLA可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将你的一个输入缓冲区回收用于存储结果。你不应该重用捐赠给计算的缓冲区;如果你尝试这样做,JAX会引发错误。默认情况下,不捐赠任何参数缓冲区。如果没有提供``donate_argnums``或``donate_argnames``,则不捐赠任何参数。如果没有提供``donate_argnums``但提供了``donate_argnames``,或者反之,JAX使用:code:inspect.signature(fun)`来查找与``donate_argnames``(或反之)对应的位置参数。如果同时提供了``donate_argnums``和``donate_argnames`,则不使用``inspect.signature``,并且只有列在``donate_argnums``或``donate_argnames``中的实际参数会被捐赠。有关缓冲区捐赠的更多详细信息,请参阅`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_。
donate_argnames (str | Iterable[str] | None) – 可选,指定哪些命名参数被捐赠给计算的字符串或字符串集合。详情请参见
donate_argnums
的注释。如果未提供但donate_argnums
已设置,默认值基于调用inspect.signature(fun)
来查找相应的命名参数。keep_unused (bool) – 可选的布尔值。如果为 False`(默认值),JAX 确定为 `fun 未使用的参数 可能 会从生成的编译 XLA 可执行文件中删除。这些参数将不会被传输到设备,也不会提供给底层可执行文件。如果为 True,未使用的参数将不会被修剪。
device (xc.Device | None) – 这是一个实验性功能,API 可能会发生变化。可选,jit 函数将运行的设备。(可用设备可以通过
jax.devices()
获取。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用jax.devices()[0]
。backend (str | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。inline (bool) – 可选的布尔值。指定此函数是否应内联到封闭的 jaxprs 中。默认为 False。
abstracted_axes (Any | None)
- 返回:
fun
的包装版本,设置为即时编译。- 返回类型:
pjit.JitWrapped
示例
在以下示例中,
selu
可以通过 XLA 编译成一个单一的融合内核:>>> import jax >>> >>> @jax.jit ... def selu(x, alpha=1.67, lmbda=1.05): ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> >>> key = jax.random.key(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
在装饰函数时传递诸如
static_argnames
的参数,一个常见的模式是使用functools.partial()
:>>> from functools import partial >>> >>> @partial(jax.jit, static_argnames=['n']) ... def g(x, n): ... for i in range(n): ... x = x ** 2 ... return x >>> >>> g(jnp.arange(4), 3) Array([ 0, 1, 256, 6561], dtype=int32)