jax.lax.fori_loop#
- jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[源代码][源代码]#
从
lower
到upper
通过归约为jax.lax.while_loop()
进行循环。简而言之,Haskell-like type signature 是
fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a
fori_loop
的语义由以下 Python 实现给出:def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): val = body_fun(i, val) return val
如Python版本所示,设置
upper <= lower
将不会产生迭代。不支持负数或自定义增量。与那个Python版本不同,
fori_loop
是基于对jax.lax.while_loop()
或jax.lax.scan()
的调用来实现的。如果迭代次数是静态的(即在跟踪时已知,可能是因为lower
和upper
是Python整数字面量),那么fori_loop
是基于scan()
实现的,并且支持反向模式自动微分;否则,将使用while_loop
,并且不支持反向模式自动微分。有关更多信息,请参阅这些函数的文档字符串。与Python的类似物不同,循环传递的值
val
在所有迭代中必须保持固定的形状和dtype(而不仅仅是在NumPy的秩/形状广播和dtype提升规则下保持一致)。换句话说,类型签名中的类型a
表示具有固定形状和dtype的数组(或具有固定结构和固定形状和dtype数组的嵌套元组/列表/字典容器数据结构)。备注
fori_loop()
编译了body_fun
,因此虽然它可以与jit()
结合使用,但通常没有必要。- 参数:
- 返回:
从最终迭代中获取的循环值,类型为
a
。