jax.lax.while_循环#
- jax.lax.while_loop(cond_fun, body_fun, init_val)[源代码][源代码]#
在
cond_fun
为 True 时,循环调用body_fun
。简而言之,Haskell-like type signature 是
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
while_loop
的语义由以下 Python 实现给出:def while_loop(cond_fun, body_fun, init_val): val = init_val while cond_fun(val): val = body_fun(val) return val
与那个Python版本不同,
while_loop
是JAX的一个原语,并且被降低为一个单一的WhileOp。这使得它在减少jit编译函数的编译时间方面很有用,因为在``@jit``函数中的原生Python循环结构会被展开,导致大型XLA计算。与Python的类似物不同,循环传递的值
val
在所有迭代中必须保持固定的形状和dtype(而不仅仅是在NumPy的秩/形状广播和dtype提升规则下保持一致)。换句话说,类型签名中的类型a
表示具有固定形状和dtype的数组(或具有固定结构和固定形状和dtype数组的嵌套元组/列表/字典容器数据结构)。与使用Python原生循环结构相比,另一个不同之处在于
while_loop
不支持反向模式微分,因为XLA计算要求内存需求具有静态边界。备注
while_loop()
编译了cond_fun
和body_fun
,因此虽然它可以与jit()
结合使用,但通常没有必要。- 参数:
cond_fun (Callable[[T], BooleanNumeric]) – 类型为
a -> Bool
的函数。body_fun (Callable[[T], T]) – 类型为
a -> a
的函数。init_val (T) – 类型
a
的值,该类型可以是标量、数组或任何 pytree(嵌套的 Python 元组/列表/字典),表示初始循环携带值。
- 返回:
来自 body_fun 最终迭代的输出,类型为
a
。- 返回类型:
T