jax.lax.while_循环

目录

jax.lax.while_循环#

jax.lax.while_loop(cond_fun, body_fun, init_val)[源代码][源代码]#

cond_fun 为 True 时,循环调用 body_fun

简而言之,Haskell-like type signature

while_loop :: (a -> Bool) -> (a -> a) -> a -> a

while_loop 的语义由以下 Python 实现给出:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

与那个Python版本不同,while_loop 是JAX的一个原语,并且被降低为一个单一的WhileOp。这使得它在减少jit编译函数的编译时间方面很有用,因为在``@jit``函数中的原生Python循环结构会被展开,导致大型XLA计算。

与Python的类似物不同,循环传递的值 val 在所有迭代中必须保持固定的形状和dtype(而不仅仅是在NumPy的秩/形状广播和dtype提升规则下保持一致)。换句话说,类型签名中的类型 a 表示具有固定形状和dtype的数组(或具有固定结构和固定形状和dtype数组的嵌套元组/列表/字典容器数据结构)。

与使用Python原生循环结构相比,另一个不同之处在于 while_loop 不支持反向模式微分,因为XLA计算要求内存需求具有静态边界。

备注

while_loop() 编译了 cond_funbody_fun,因此虽然它可以与 jit() 结合使用,但通常没有必要。

参数:
  • cond_fun (Callable[[T], BooleanNumeric]) – 类型为 a -> Bool 的函数。

  • body_fun (Callable[[T], T]) – 类型为 a -> a 的函数。

  • init_val (T) – 类型 a 的值,该类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典),表示初始循环携带值。

返回:

来自 body_fun 最终迭代的输出,类型为 a

返回类型:

T