jax.lax.map

目录

jax.lax.map#

jax.lax.map(f, xs, *, batch_size=None)[源代码][源代码]#

对数组的前导轴应用函数。

类似于Python的内置map,不同之处在于输入和输出是以堆叠数组的形式。除非你需要逐元素应用函数以减少内存使用或与其他控制流原语进行异构计算,否则请考虑使用 vmap() 变换。

xs 是一个数组类型时,map() 的语义由这个 Python 实现给出:

def map(f, xs):
  return np.stack([f(x) for x in xs])

类似于 scan()map() 也是基于 JAX 原语实现的,因此许多与 Python 循环相比的相同优势适用:xs 可以是任意嵌套的 pytree 类型,并且映射的计算只编译一次。

如果提供了 batch_size ,计算将以该大小的批次执行,并使用 vmap() 进行并行化。这可以作为 map 的更高性能版本或 vmap 的内存高效版本使用。如果轴不能被批次大小整除,剩余部分将在单独的 vmap 中处理,并连接到结果中。

>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)

在上面的例子中,“内部形状”被打印了两次,一次是在追踪批量计算时,一次是在追踪剩余计算时。

参数:
  • f – 一个Python函数,用于在 xs 的第一个轴或多个轴上逐元素应用。

  • xs – 沿主要轴映射的值。

  • batch_size (int | None) – (可选) 整数,指定每个步骤并行执行的批次大小。

返回:

映射值。