jax.pmap#
- jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)[源代码][源代码]#
支持集体操作的并行映射。
pmap
的目的是表达单程序多数据(SPMD)程序。将pmap
应用于一个函数将使用XLA编译该函数(类似于jit
),然后在XLA设备上并行执行,例如多个GPU或多个TPU核心。从语义上讲,它与vmap
相当,因为这两种变换都在数组轴上对函数进行映射,但vmap
通过将映射轴推入基本操作来向量化函数,而pmap
则复制函数并在其自己的XLA设备上并行执行每个副本。映射的轴大小必须小于或等于可用的本地XLA设备的数量,如
jax.local_device_count()
返回的(除非指定了devices
,见下文)。对于嵌套的pmap()
调用,映射轴大小的乘积必须小于或等于XLA设备的数量。pmap()
要求所有参与的设备都是相同的。例如,不可能使用pmap()
在两种不同型号的 GPU 上并行计算。对于同一设备在同一 pmap 中参与两次,目前是一个错误。多进程平台: 在多进程平台如 TPU pods 上,
pmap()
设计用于 SPMD Python 程序,其中每个进程都运行相同的 Python 代码,使得所有进程按相同顺序运行相同的 pmapped 函数。每个进程仍应使用映射轴大小等于 本地 设备数量的参数调用 pmapped 函数(除非指定了devices
,见下文),并且将像往常一样返回具有相同前导轴大小的数组。然而,fun
中的任何集体操作将通过设备间通信在 所有 参与设备上计算,包括其他进程上的设备。从概念上讲,这可以被视为在一个跨进程分片的单个数组上运行 pmap,其中每个进程“看到”的只是其本地输入和输出的分片。SPMD 模型要求所有设备上必须按相同顺序运行相同的多进程 pmaps,但它们可以与单个进程中运行的任意操作交错。- 参数:
fun (Callable) – 要映射到参数轴上的函数。它的参数和返回值应为数组、标量或(嵌套的)标准 Python 容器(元组/列表/字典)。由
static_broadcasted_argnums
指示的位置参数可以是任何东西,前提是它们是可哈希的并且定义了相等操作。axis_name (AxisName | None) – 可选的,一个可哈希的Python对象,用于标识映射的轴,以便可以应用并行集体操作。
in_axes – 一个非负整数、None 或嵌套的 Python 容器,指定要映射的位置参数的轴。作为关键字传递的参数总是映射在其前导轴(即轴索引 0)上。详情请参见
vmap()
。out_axes – 一个非负整数、None,或嵌套的Python容器,指示映射轴应在输出中的位置。所有具有映射轴的输出必须具有非None的
out_axes
规范(参见vmap()
)。static_broadcasted_argnums (int | Iterable[int]) – 一个整数或整数的集合,指定哪些位置参数作为静态(编译时常量)处理。仅依赖于静态参数的操作将被常量折叠。使用不同的静态常量值调用pmapped函数将触发重新编译。如果调用pmapped函数时提供的位置参数少于``static_broadcasted_argnums``指示的数量,则会引发错误。每个静态参数将被广播到所有设备。不是数组或其容器的参数必须标记为静态。默认值为()。静态参数必须是可哈希的,这意味着``__hash__``和``__eq__``都已实现,并且应该是不可变的。
devices (Sequence[xc.Device] | None) – 这是一个实验性功能,API 可能会发生变化。可选,一个设备序列以进行映射。(可用设备可以通过 jax.devices() 获取)。在多进程设置中,必须为每个进程提供相同的设备序列(因此将包括跨进程的设备)。如果指定,映射轴的大小必须等于给定进程的序列中设备的数量。在内部或外部
pmap()
中指定devices
的嵌套pmap()
尚未支持。backend (str | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串。’cpu’、’gpu’ 或 ‘tpu’。
axis_size (int | None) – 可选;映射轴的大小。
donate_argnums (int | Iterable[int]) – 指定哪些位置参数缓冲区被“捐赠”给计算。如果你在计算完成后不再需要这些缓冲区,捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如将输入缓冲区之一回收用于存储结果。你不应该重用捐赠给计算的缓冲区,如果你尝试这样做,JAX 会抛出错误。请注意,donate_argnums 仅适用于位置参数,关键字参数不会被捐赠。有关缓冲区捐赠的更多详情,请参阅 FAQ。
- 返回:
fun
的并行版本,其参数对应于fun
的参数,但在in_axes
指示的位置有额外的数组轴,并且输出具有一个额外的领先数组轴(大小相同)。- 返回类型:
Any
例如,假设有8个XLA设备可用,
pmap()
可以作为一个沿前导数组轴的映射使用:>>> import jax.numpy as jnp >>> >>> out = pmap(lambda x: x ** 2)(jnp.arange(8)) >>> print(out) [0, 1, 4, 9, 16, 25, 36, 49]
当主导维度小于可用设备的数量时,JAX 将仅在一部分设备上运行:
>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2 >>> out = pmap(jnp.dot)(x, y) >>> print(out) [[[ 4. 9.] [ 12. 29.]] [[ 244. 345.] [ 348. 493.]] [[ 1412. 1737.] [ 1740. 2141.]]]
如果你的主维度大于可用设备的数量,你将会遇到错误:
>>> pmap(lambda x: x ** 2)(jnp.arange(9)) ValueError: ... requires 9 replicas, but only 8 XLA devices are available
与
vmap()
类似,在in_axes
中使用None
表示某个参数没有额外的轴,应该在整个副本中进行广播,而不是映射:>>> x, y = jnp.arange(2.), 4. >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y) >>> print(out) ([4., 5.], [8., 8.])
注意,
pmap()
总是返回在其前导轴上映射的值,相当于在vmap()
中使用out_axes=0
。除了表示纯地图外,
pmap()
还可以用于表示通过集体操作进行通信的并行单程序多数据(SPMD)程序。例如:>>> f = lambda x: x / jax.lax.psum(x, axis_name='i') >>> out = pmap(f, axis_name='i')(jnp.arange(4.)) >>> print(out) [ 0. 0.16666667 0.33333334 0.5 ] >>> print(out.sum()) 1.0
在这个例子中,
axis_name
是一个字符串,但它可以是任何定义了__hash__
和__eq__
的 Python 对象。参数
axis_name
用于pmap()
命名映射的轴,以便集体操作(如jax.lax.psum()
)可以引用它。轴名在嵌套的pmap()
函数中尤为重要,因为集体操作可以作用于不同的轴:>>> from functools import partial >>> import jax >>> >>> @partial(pmap, axis_name='rows') ... @partial(pmap, axis_name='cols') ... def normalize(x): ... row_normed = x / jax.lax.psum(x, 'rows') ... col_normed = x / jax.lax.psum(x, 'cols') ... doubly_normed = x / jax.lax.psum(x, ('rows', 'cols')) ... return row_normed, col_normed, doubly_normed >>> >>> x = jnp.arange(8.).reshape((4, 2)) >>> row_normed, col_normed, doubly_normed = normalize(x) >>> print(row_normed.sum(0)) [ 1. 1.] >>> print(col_normed.sum(1)) [ 1. 1. 1. 1.] >>> print(doubly_normed.sum((0, 1))) 1.0
在多进程平台上,集体操作会覆盖所有设备,包括其他进程上的设备。例如,假设以下代码在每个进程有4个XLA设备的两个进程上运行:
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i') >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8) >>> out = pmap(f, axis_name='i')(data) >>> print(out) [28 29 30 31] # on process 0 [32 33 34 35] # on process 1
每个进程传入一个不同的长度为4的数组,对应其4个本地设备,psum操作覆盖所有8个值。从概念上讲,这两个长度为4的数组可以被视为一个分片的长度为8的数组(在此示例中等同于jnp.arange(8)),该数组被映射,长度为8的映射轴被赋予名称’i’。然后,每个进程上的pmap调用返回相应的长度为4的输出分片。
devices
参数可以用来指定哪些设备用于运行并行计算。例如,再次假设一个进程有8个设备,以下代码定义了两个并行计算,一个运行在前六个设备上,另一个运行在剩下的两个设备上:>>> from functools import partial >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6]) ... def f1(x): ... return x / jax.lax.psum(x, axis_name='i') >>> >>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:]) ... def f2(x): ... return jax.lax.psum(x ** 2, axis_name='i') >>> >>> print(f1(jnp.arange(6.))) [0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333] >>> print(f2(jnp.array([2., 3.]))) [ 13. 13.]