jax.lax.fori_loop

目录

jax.lax.fori_loop#

jax.lax.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[源代码][源代码]#

lowerupper 通过归约为 jax.lax.while_loop() 进行循环。

简而言之,Haskell-like type signature

fori_loop :: Int -> Int -> ((Int, a) -> a) -> a -> a

fori_loop 的语义由以下 Python 实现给出:

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

如Python版本所示,设置 upper <= lower 将不会产生迭代。不支持负数或自定义增量。

与那个Python版本不同,fori_loop 是基于对 jax.lax.while_loop()jax.lax.scan() 的调用来实现的。如果迭代次数是静态的(即在跟踪时已知,可能是因为 lowerupper 是Python整数字面量),那么 fori_loop 是基于 scan() 实现的,并且支持反向模式自动微分;否则,将使用 while_loop,并且不支持反向模式自动微分。有关更多信息,请参阅这些函数的文档字符串。

与Python的类似物不同,循环传递的值 val 在所有迭代中必须保持固定的形状和dtype(而不仅仅是在NumPy的秩/形状广播和dtype提升规则下保持一致)。换句话说,类型签名中的类型 a 表示具有固定形状和dtype的数组(或具有固定结构和固定形状和dtype数组的嵌套元组/列表/字典容器数据结构)。

备注

fori_loop() 编译了 body_fun ,因此虽然它可以与 jit() 结合使用,但通常没有必要。

参数:
  • lower – 一个表示循环索引下限(包含)的整数

  • upper – 一个表示循环索引上限(不包括)的整数

  • body_fun – 类型为 (int, a) -> a 的函数。

  • init_val – 类型为 a 的初始循环进位值。

  • unroll (int | bool | None) – 一个可选的整数或布尔值,用于确定循环展开的程度。如果提供了一个整数,它决定了在循环的单次迭代中运行多少次展开的循环迭代。如果提供了一个布尔值,它将决定循环是完全展开(即 unroll=True)还是完全不展开(即 unroll=False)。此参数仅在循环边界是静态已知的情况下适用。

返回:

从最终迭代中获取的循环值,类型为 a