jax.lax.优化屏障

jax.lax.优化屏障#

jax.lax.optimization_barrier(operand, /)[源代码][源代码]#

防止编译器将操作跨越屏障移动。

优化屏障有多种可能的用途:

  • 优化屏障确保所有输入在依赖于屏障输出的任何操作符之前被评估。这可以用来强制执行特定的操作顺序。

  • 优化屏障防止公共子表达式消除。JAX 使用此功能来实现重计算。

  • 优化屏障阻止编译器融合。也就是说,屏障之前的操作可能不会被编译器融合到与屏障之后的操作相同的内核中。

JAX 没有为优化屏障定义导数或批处理规则。

优化屏障在编译函数之外无效。

参数:

operand – JAX 值的 pytree。

返回:

JAX 值的 pytree,具有与 operand 相同的结构和内容。

示例

防止在两次调用 sin 之间进行公共子表达式消除:

>>> def f(x):
...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
>>> jax.jit(f)(0.)
Array(0., dtype=float32, weak_type=True)