shmap (shard_map) 用于简单的按设备代码#

sholto@, sharadmv@, jekbradbury@, zhangqiaorjc@, mattjj@

2023年1月

动机#

JAX 支持两种多设备编程的思想:

  1. 编译器,接管方向盘! 让编译器自动在设备上分区批量数组函数。

  2. 让我直接表达我的意思,该死! 给我每个设备的代码和明确的通信集合。

我们需要为两者都提供优秀的API,并且它们不应是互斥的替代方案,而是需要相互组合。

使用 pjit(现在只是 jit),我们为第一所学校提供了一个下一代API。但我们还没有完全升级第二所学校。pmap遵循第二所学校,但随着时间的推移,我们发现它有致命缺陷xmap解决了这些缺陷,但它并没有完全给我们提供每个设备的形状,而且还包含了一些其他的大想法。同时,新的需求出现了,比如在高效扩展Transformer推理中,需要每个设备的显式集体编程。

我们可以用 shmap 升级第二所学校。shmap 是:

  • 一个简单的多设备并行API,允许我们编写每个设备的代码,并明确使用集合操作,其中逻辑形状与每个设备的物理缓冲区形状匹配,并且集合操作完全对应于跨设备的通信。

  • xmap 的一个特化版本,具有缩减的功能和一些调整;

  • XLA SPMD 分区器的 ‘手动’ 模式的相当直接的表面化;

  • 一个有趣且容易发音的Seussian名字,可以代表 shard_mapshpecialized_xmapsholto_mapsharad_map

对于 pjit 用户shmap 是一个补充工具。它可以在 pjit 计算中使用,暂时进入“手动集体”模式,就像从编译器的自动分区中逃脱一样。这样,用户可以在大部分代码中获得 pjit 的便利性和熟悉的仅 NumPy 编程模型,同时还能在需要的地方使用 shmap 手动优化集体通信。这是两全其美!

对于 pmap 用户shmap 是一个严格的升级。它在表达能力、性能以及与其他 JAX API 的组合性方面更加出色,同时不会使基本的批量数据并行变得更加困难。

关于实际使用的更多信息,您可以跳转到 何时应该使用 shmap 以及何时应该使用 pjit。如果您想知道为什么我们需要一个新东西,或者 pmap 存在什么问题,请跳转到 为什么 pmapxmap 不能解决这个问题?。或者继续阅读下一节,查看一些 shmap 示例和 API 规范。

那么,让我们看看 shmap#

TL;DR 示例(随后有更详细的解释)#

Sho shick:

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  # a_block: f32[2, 8]
  # b_block: f32[8, 32]
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]

注意:

  • 不需要嵌套(或 axis_index_groups)来处理多个并行轴,不同于 pmap

  • pmap 和 hard-xmap 不同,调用者中没有重塑操作,并且逻辑形状对应于每个设备的物理形状,这与(非硬)xmap 不同;

  • 通过使用 mesh 实现精确的设备布局控制,不同于 pmap

  • 逻辑和物理轴名只有一组,不像 xmap

  • 结果是一个 jax.Array,它可以高效地传递给 pjit,不像 pmap

  • 这段相同的代码在 pjit/jit 内部高效运行,不同于 pmap

  • 这段代码是急切执行的,因此我们可以在中间使用 pdb 并打印值,这与 xmap 当前的实现不同(尽管按照设计,xmap 在没有顺序调度的情况下原则上也可以急切执行)。

这里有一个完全分片结果的矩阵乘法变体:

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', 'j'))
def matmul_reduce_scatter(a_block, b_block):
  # c_partialsum: f32[8/X, 32]
  c_partialsum = jnp.matmul(a_block, b_block)
  # c_block: f32[8/X, 32/Y]
  c_block = jax.lax.psum_scatter(c_partialsum, 'j', scatter_dimension=1, tiled=True)
  return c_block

c = matmul_reduce_scatter(a, b)

慢下来,从基础开始!#

数组轴上的秩减少映射与秩保持映射#

我们可以将 pmap(以及 vmapxmap)视为沿着一个轴对每个数组输入进行解堆叠(例如,将一个2D矩阵解包为其1D行),将其主体函数应用于每个部分,并将结果重新堆叠在一起,至少在不涉及集体操作时:

pmap(f, in_axes=[0], out_axes=0)(xs) == jnp.stack([f(x) for x in xs])

例如,如果 xs 的形状是 f32[8,5],那么每个 x 的形状是 f32[5],并且如果每个 f(x) 的形状是 f32[3,7],那么最终堆叠的结果 pmap(f)(xs) 的形状是 f32[8,3,7]。也就是说,函数体 f 的每次应用都将比 pmap(f) 的相应参数少一个轴的输入作为参数。我们可以说这些是带有输入/输出解堆叠/堆叠的 降秩映射

f 的逻辑应用次数由被映射的输入轴的大小决定:例如,如果我们映射一个大小为8的输入轴,从语义上我们得到8次函数的逻辑应用,这些应用对于 pmap 总是对应于8个物理设备进行计算。

相比之下,shmap 没有这种降低秩的行为。相反,我们可以将其视为沿输入轴切片(或“解连接”)为块,应用主体函数,然后将结果连接回一起(同样,当不涉及集体操作时):

devices = np.array(jax.devices()[:4])
m = Mesh(devices, ('i',))  # mesh.shape['i'] = 4

shard_map(f, m, in_specs=P('i'), out_specs=P('i'))(y)
==
jnp.concatenate([f(y_blk) for y_blk in jnp.split(y, 4)])

回想一下,jnp.split 将其输入分割成相同秩的等大小块,因此在上面的例子中,如果 y 的形状是 f32[8,5],那么每个 y_blk 的形状是 f32[2,5],如果每个 f(y_blk) 的形状是 f32[3,7],那么最终连接的结果 shard_map(f, ...)(y) 的形状是 f32[12,7]。所以 shmap (shard_map) 映射在其输入的分片或块上。我们可以说它是一个 保持秩的映射,具有对其输入/输出的解连接/连接操作。

f 的逻辑应用数量由网格大小决定,而不是由任何输入轴大小决定:例如,如果我们有一个总大小为4的网格(即在4个设备上),那么从语义上我们得到函数的4个逻辑应用,对应于4个物理计算它们的设备。

控制每个输入如何被分割(未连接)和用 in_specs 平铺#

每个 in_specs 通过使用 PartitionSpec 按名称识别输入数组的一些轴与网格轴的对应关系,表示如何将该输入分割(或解连接)为应用主体函数的块。这种识别决定了分片大小;当输入轴与网格轴对应时,输入沿该逻辑轴被分割(解连接)为与相应网格轴大小相等的若干部分。(如果相应的网格轴大小不能均匀地分割输入数组轴大小,则会出错。)如果输入的 pspec 没有提到网格轴名称,则不会在该网格轴上进行分割。例如:

devices = np.array(jax.devices())
m = Mesh(devices.reshape(4, 2), ('i', 'j'))

@partial(shard_map, mesh=m, in_specs=P('i', None), out_specs=P('i', 'j'))
def f1(x_block):
  print(x_block.shape)
  return x_block

x1 = np.arange(12 * 12).reshape(12, 12)
y = f1(x1)  # prints (3,12)

在这里,因为输入的 pspec 没有提到网格轴名称 'j',所以没有输入数组的轴被分割到该网格轴上;同样,因为输入数组的第二个轴没有与(因此也没有被分割到)任何网格轴关联,所以应用 f1 时,在该轴上可以看到输入的完整视图。

当输入的pspec中没有提到网格轴时,我们总是可以重写为一个效率较低的程序,其中所有网格轴都被提及,但调用者执行 jnp.tile,例如:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', 'j'))
def f2(x_block):
  print(x_block.shape)
  return x_block

x = np.arange(12 * 12).reshape(12, 12)
x_ = jnp.tile(x, (1, mesh.axis_size['j']))  # x_ has shape (12, 24)
y = f2(x_)  # prints (3,12), and f1(x) == f2(x_)

换句话说,因为每个输入的 pspec 可以提到每个网格轴名称零次或一次,而不是必须提到每个名称恰好一次,我们可以说除了 jnp.split 内置于其输入之外,shard_map 还内置了一个 jnp.tile,至少在逻辑上是这样(尽管根据参数的物理分片布局,可能不需要实际执行平铺)。要使用的平铺不是唯一的;我们也可以沿着第一个轴平铺,并使用 pspec P(('j', 'i'), None)

物理数据移动在输入上是可能的,因为每个设备都需要有一份适当的数据副本。

通过 out_specs 控制每个输出的组装方式,包括连接、块转置和解块。#

类似于输入端,每个 out_specs 通过名称识别输出数组的一些轴与网格轴的对应关系,表示如何将输出块(每个体函数应用一个,或等效地每个物理设备一个)重新组合以形成最终的输出值。例如,在上述 f1f2 示例中,out_specs 指示我们应该沿着两个轴将块结果连接在一起,从而在两种情况下形成形状为 (12,24) 的数组 y。(如果体函数的输出形状,即输出块的形状,其秩对于相应的输出 pspec 描述的连接来说太小,则会出现错误。)

当网格轴名称未在输出 pspec 中提及,它表示一个 非平铺:当用户编写一个未提及某个网格轴名称的输出 pspec 时,他们承诺输出块在该网格轴上是相等的,因此在该轴上只使用一个块(而不是将该轴上的所有块连接在一起)。例如,使用与上面相同的网格:

x = jnp.array([[3.]])

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', 'j'))()
print(z)  # prints the same as jnp.tile(x, (4, 2))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P('i', None))()
print(z)  # prints the same as jnp.tile(x, (4, 1)), or just jnp.tile(x, (4,))

z = shard_map(lambda: x, mesh=m, in_specs=(), out_specs=P(None, None))()
print(z)  # prints the same as jnp.tile(x, (1, 1)), or just x

注意,主体函数关闭一个数组值等同于将其作为参数传递,并带有相应的输入 pspec P(None, None)。作为另一个例子,更接近于上述其他例子:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P('i', None))
def f3(x_block):
  return jax.lax.psum(x_block, 'j')

x = np.arange(12 * 12).reshape(12, 12)
y3 = f3(x)
print(y3.shape)  # (12,6)

请注意,结果的第二个轴大小为6,是输入的第二个轴大小的一半。在这种情况下,由于集体 psum 的存在,未提及网格轴名称 'j' 在输出 pspec 中表示的未平铺是安全的,这确保了每个输出块在相应的网格轴上是相等的。以下是另外两个例子,我们改变在输出 pspec 中提及的网格轴:

@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, 'j'))
def f4(x_block):
  return jax.lax.psum(x_block, 'i')

x = np.arange(12 * 12).reshape(12, 12)
y4 = f4(x)
print(y4.shape)  # (3,12)


@partial(shard_map, mesh=m, in_specs=P('i', 'j'), out_specs=P(None, None))
def f5(x_block):
  return jax.lax.psum(x_block, ('i', 'j'))

y5 = f5(x)
print(y5.shape)  # (3,6)

在物理方面,如果在输出pspec中不提及网格轴名称,则会从输出设备缓冲区中组装一个 Array,并在该网格轴上沿复制布局。

在运行时没有检查输出块实际上是否沿着网格轴相等,以便可以沿着该轴进行解块,或者等效地检查相应的物理缓冲区是否具有相等的值,从而可以解释为单个逻辑数组的复制布局。但是,我们可以提供一个静态检查机制,该机制会在所有可能不正确的程序上引发错误。

因为 out_specs 可以提到网格轴名称零次或一次,并且因为它们可以按任何顺序提及,我们可以说,除了其输出中内置的 jnp.concatenate 之外,shard_map 的输出还内置了无平铺和块转置。

在输出上不可能进行物理数据移动,无论输出 pspec 如何。相反,out_specs 仅编码如何将块输出组装成 Array,或者物理上如何跨设备解释缓冲区,以形成单个逻辑 Array 的物理布局。

API 规范#

from jax.sharding import Mesh
Specs = PyTree[PartitionSpec]

def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs
          ) -> Callable:
  ...

哪里:

  • mesh 编码了以数组排列的设备,并带有相关的轴名称,就像它对 xmapsharding.NamedSharding 所做的那样;

  • in_specsout_specsPartitionSpec,它们可以 仿射地 提及来自 mesh 的轴名称(不像 xmap 中那样使用独立的逻辑名称),分别表示输入和输出的切片/解联和联接(不像 pmapxmap 那样进行解堆叠和堆叠),未提及的名称分别对应于复制和去平铺(断言-复制-所以-给我-一个-副本)。

  • 传递给 f 的参数的形状与传递给 shard_map-of-f 的参数的形状具有相同的秩(与 pmapxmap 不同,后者的秩会减少),并且 f 的参数的形状是根据 shard_map-of-f 的相应参数的形状 shape 和相应的 PartitionSpec spec 计算得出的,大致为 tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))

  • f 的主体可以使用 mesh 中的名称应用集体操作。

shmap 默认是急切的,这意味着我们逐个原语地分派计算,以便用户可以在完全复制的值上使用 Python 控制流和交互式 pdb 调试来打印任何值。要分阶段输出并端到端编译一个 shmap 函数,只需在其周围放置一个 jit。一个结果是 shmap 不像 xmappmap 那样有自己的分派和编译路径;它只是 jit 路径。

当它被例如一个封闭的 jit 分阶段处理时,shmap 降低到 StableHLO 是微不足道的:它只是在输入上切换到 ‘手动 SPMD 模式’,并在输出上切换回来。(我们目前不计划支持部分手动部分自动的模式。)

与效果的交互与 pmap 相同。

与自动微分的交互也与 pmap 类似(而不是尝试 xmap 引入的新语义,对应于具有未映射的中间结果,因此 gradreduce_axes 以及使 psum 转置为 pbroadcast 而不是 psum)。但它因此继承了 pmap 的一个未解决问题:在某些情况下,与其将 psum 转置为 psum,从而执行与前向传递 psum 对应的反向传递 psum,不如将反向传递 psum 移动到反向传递的其他位置,利用线性特性。许多高级 pmap 用户通过使用 custom_vjp 来实现 psum_idrevid_psumrev 函数来解决这一挑战,但由于很容易意外地使这些不平衡,这种技术是一个潜在的危险。我们有一些想法如何以更安全的方式提供这种功能。

何时应该使用 shmap 以及何时应该使用 pjit#

一种理念是:几乎总是更简单地用 jit==pjit 编写程序 — 但如果程序的某部分编译器优化得不如预期,就切换到 shmap

一个真实的变压器示例#

事实上,我们可以使用shmap和30行Python代码实现XLA中最近引入的”集体矩阵乘法”算法的简单版本,以重叠通信和计算。该算法的基本思想可以通过一个简单的例子来理解。

假设我们要计算 C = A @ B,其中 A 在第0维度上通过1D网格进行分片,而 BC 是复制的。

M, K, N = 4096, 2048, 1024
A = jnp.arange(np.prod((M, K))).reshape((M, K))
B = jnp.arange(np.prod((K, N))).reshape((K, N))

mesh = Mesh(np.array(jax.devices()), axis_names=('i'))
A_x = jax.device_put(A, NamedSharding(mesh, P('i', None)))

@jax.jit
def f(lhs, rhs):
  return lhs @ rhs

C = f(A_x, B)

一个配置文件显示了在matmul开始之前,跨8个设备的阻塞式all-gather。这是次优的,因为A在非收缩维度上被分片,并且A的每个分片可以独立地与B进行matmul运算,这种分块计算可以与从另一个设备获取A的下一个分片重叠进行。

image

这种重叠可以通过使用 shmap 和显式集合来实现。

def collective_matmul_allgather_lhs_non_contracting(lhs, rhs):
  # lhs is the looped operand; rhs is the local operand
  axis_size = jax.lax.psum(1, axis_name='i')
  axis_index = jax.lax.axis_index(axis_name='i')
  chunk_size = lhs.shape[0]

  def f(i, carrys):
    accum, lhs = carrys
    # matmul for a chunk
    update = lhs @ rhs
    # circular shift to the left
    lhs = jax.lax.ppermute(
        lhs,
        axis_name='i',
        perm=[(j, (j - 1) % axis_size) for j in range(axis_size)]
    )
    # device 0 computes chunks 0, 1, ...
    # device 1 computes chunks 1, 2, ...
    update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
    accum = jax.lax.dynamic_update_slice(accum, update, update_index)
    return accum, lhs

  accum = jnp.zeros((lhs.shape[0] * axis_size, rhs.shape[1]), dtype=lhs.dtype)
  # fori_loop cause a crash: hlo_sharding.cc:817 Check failed: !IsManual()
  # accum, lhs = jax.lax.fori_loop(0, axis_size - 1, f, (accum, lhs))
  for i in range(0, axis_size - 1):
    accum, lhs = f(i, (accum, lhs))

  # compute the last chunk, without the ppermute
  update = lhs @ rhs
  i = axis_size - 1
  update_index = (((axis_index + i) % axis_size) * chunk_size, 0)
  accum = jax.lax.dynamic_update_slice(accum, update, update_index)

  return accum
jit_sharded_f = jax.jit(shard_map(
  collective_matmul_allgather_lhs_non_contracting, mesh,
  in_specs=(P('i', None), P()), out_specs=P()))
C = jit_sharded_f(A_x, B)

一个分析显示,all-gather 已经消失,取而代之的是与异步集体置换重叠的矩阵乘法。这个分析与集体矩阵乘法论文的结果非常吻合。

image

这种集体矩阵乘法技术可以用来加速变压器层中的前馈块。这通常包括两个矩阵乘法,随后是一个 ReduceScatter(用于解决从并行矩阵乘法中得到的局部和),以及前面是一个 AllGather(用于收集某些轴上的分片维度并允许局部和计算)。总的来说,一层的 ReduceScatter 和下一层的 AllGather 相当于一个 AllReduce

在一个典型的配置文件中,两个矩阵乘法之后会跟随一个 AllReduce,并且它们不会重叠。集体矩阵乘法可以用来实现重叠,但触发困难,有最小切片大小,并且尚未覆盖所有拓扑结构、张量形状和集体矩阵乘法的变体(即延迟和吞吐量优化的变体)。在最近的一篇论文中,我们发现通过手动在 shmap 风格中实现集体矩阵乘法变体,在许多情况下可以获得约40%的增益。

但这并不总是更复杂!我们期望这是一种更自然的思考管道计算的方式,并计划很快展示一些相关的演示!

另一个现实世界的例子#

以下是 shmap 在具有二维权重收集模式的变换器层传递中可能的外观(论文,第5页的第3.2.3节):

def matmul_2D_wg_manual(xnorm, q_wi, layer):
  '''Calls a custom manual implementation of matmul_reducescatter'''
  # [batch, maxlen, embed.X] @ [heads.YZ, embed.X, q_wi_per_head]
  # -> (matmul)
  # -> [batch, maxlen, heads.YZ, q_wi_per_head]{x unreduced}
  # -> (reducescatter over x into X heads, B batches)
  # -> [batch, maxlen, heads.YZX, q_wi_per_head]
  with jax.named_scope('q_wi'):
    xnorm = intermediate_dtype(xnorm)
    q_wi = matmul_reducescatter(
        'bte,hed->bthd',
        xnorm,
        params.q_wi,
        scatter_dimension=(0, 2),
        axis_name='i',
        layer=layer)
   return q_wi


import partitioning.logical_to_physical as l2phys

def pjit_transformer_layer(
    hparams: HParams, layer: int, params: weights.Layer, sin: jnp.ndarray,
    cos: jnp.ndarray, kv_caches: Sequence[attention.KVCache],
    x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Forward pass through a single layer, returning output, K, V."""

  def my_layer(t, axis=0):
    """Gets the parameters corresponding to a given layer."""
    return lax.dynamic_index_in_dim(t, layer, axis=axis, keepdims=False)

  # 2D: [batch.Z, time, embed.XY]
  x = _with_sharding_constraint(
      x, ('residual_batch', 'residual_time', 'residual_embed'))
  xnorm = _layernorm(x)
  # 2D: [batch, time, embed.X]
  xnorm = _with_sharding_constraint(
      xnorm, ('post_norm_batch', 'time', 'post_norm_embed'))
  # jump into manual mode where you want to optimise
  if manual:
    q_wi = shard_map(matmul_2D_wg_manual, mesh
                in_specs=(l2phys('post_norm_batch', 'time', 'post_norm_embed'),
                          l2phys('layers', 'heads', 'embed', 'q_wi_per_head')),
                out_specs=l2phys('post_norm_batch', 'time', 'heads', 'q_wi_per_head'))(xnorm, q_wi, layer)
  else:
    q_wi = jnp.einsum('bte,hed->bthd', xnorm, my_layer(params.q_wi))
    # 2D: [batch, time, heads.YZX, None]
    q_wi = _with_sharding_constraint(q_wi,
                                   ('post_norm_batch', 'time', 'heads', 'qkv'))
  q = q_wi[:, :, :, :hparams.qkv]
  q = _rope(sin, cos, q)
  # unlike in https://arxiv.org/pdf/2002.05202.pdf, PaLM implements
  # swiGLU with full d_ff dimension, rather than 2/3 scaled
  wi0 = q_wi[:, :, :, hparams.qkv:hparams.qkv + (hparams.ff // hparams.heads)]
  wi1 = q_wi[:, :, :, hparams.qkv + (hparams.ff // hparams.heads):]
  kv = jnp.einsum('bte,ezd->btzd', xnorm, my_layer(params.kv))
  k = kv[:, :, 0, :hparams.qkv]
  v = kv[:, :, 0, hparams.qkv:]
  k = _rope(sin, cos, k)

  y_att = jnp.bfloat16(attention.attend(q, k, v, kv_caches, layer))

  y_mlp = special2.swish2(wi0) * wi1
  # 2D: [batch, time, heads.YZX, None]
  y_mlp = _with_sharding_constraint(y_mlp,
                                    ('post_norm_batch', 'time', 'heads', None))

  y_fused = jnp.concatenate([y_att, y_mlp], axis=-1)
  # do the second half of the mlp and the self-attn projection in parallel
  y_out = jnp.einsum('bthd,hde->bte', y_fused, my_layer(params.o_wo))
  # 2D: [batch.Z, time, embed.XY]
  y_out = _with_sharding_constraint(
      y_out, ('residual_batch', 'residual_time', 'residual_embed'))
  z = y_out + x
  z = _with_sharding_constraint(
      z, ('residual_batch', 'residual_time', 'residual_embed'))
  return z, k, v

在下面的配置文件中,第一个和第二个 matmul 都被手动降低的版本所替代,其中计算(融合)与通信(ppermute)完全重叠!一个有趣的提示是,我们使用的是延迟优化的变体,即 ppmerute 像素是抖动的 — 因为在同一时间有两个使用相反 ICI 轴的 ppermute 重叠!

全对全(All-to-all)更难重叠,因此被搁置。

image

为什么 pmapxmap 还没有解决这个问题?#

pmap 是我们第一个多设备并行 API。它遵循每设备代码和显式集合流派。但它有重大缺陷,使其不适合当今的程序:

  • 映射多个轴需要嵌套的 pmap 嵌套的 pmap 不仅写起来繁琐,而且它们使得控制(甚至预测)数据和计算的设备放置变得困难,并且难以保持数据分片(见接下来的两个要点)。今天的程序需要多轴并行。

  • 控制设备布局是不可能的。 特别是在多轴并行的情况下,程序员需要控制这些轴如何与硬件资源及其通信拓扑对齐。但是(嵌套的)pmap 不提供对映射程序实例在硬件上如何放置的控制;用户无法控制的是一个自动的设备顺序。(Gopher 使用 axis_index_groups 和一个单一的非嵌套 pmap 本质上是通过将多个并行轴扁平化为一个来绕过这一点。)

  • jit/pjit 组合性。 jit-of-pmap 是一个性能陷阱,嵌套 pmap 也是如此,例如 scan-of-pmap,因为在从内部 pmap 返回时,分片不会被保留。为了保留分片,我们需要对 jaxprs 进行模式匹配,以确保我们处理的是完美嵌套的 pmap,或者在 jit 内部只有一个 pmap。此外,pjit 在这里没有帮助,因为 pmap 针对 XLA 副本,而 pjit 针对 XLA SPMD 分区器,组合这两者很困难。

  • jax.Array 兼容性(因此 pjit 兼容性)。 由于 pmap 输出的分片无法表示为 Shardings / OpShardings,这是由于 pmap 的堆叠而非连接语义,pmap 计算的输出目前无法在不反弹到主机(或调度重塑计算)的情况下传递给 pjit 计算。

  • 多控制器语义(因此与 pjit 兼容)。 多控制器的 pmap 将值跨控制器连接,这虽然有效,但与单控制器的 pmap 的堆叠语义不同。更实际的是,它排除了使用非完全可寻址的 jax.Array 输入和输出,就像我们在多控制器 pjit 中使用的那样。

  • 急切模式。 我们没有首先让 pmap 急切执行,尽管我们最终(在4年多之后!)通过 disable_jit() 添加了急切操作,但 pmapjit 融合的事实意味着它有自己的编译和调度路径(实际上有两条调度路径:在 Python 中处理 Tracer,以及在 C++ 中处理原始 Array 输入以提高性能!),这是一个沉重的实现负担。

  • 调用者中需要的重塑。 一个典型的使用 pmap 在8个设备上的用例可能看起来像是从大小为128的批次轴开始,将其重塑为大小为(8, 16)的两个轴,然后对第一个轴进行 pmap 操作。这些重塑操作很尴尬,编译器通常会将它们解释为复制而不是视图 — 增加了内存和时间的使用。

这些缺点在只进行批量数据并行时并不算太糟。但当涉及更多并行性时,pmap 就无法胜任了!

xmap 作为 pmap 的下一代进化版,解决了(几乎)所有这些问题。shmap 跟随 xmap 的脚步,以本质上相同的方式解决了这些问题;实际上,shmap 就像是 xmap 的一个专门子集(有些人称之为“硬 xmap”子集),并进行了一些调整。

在初始原型中,我们选择将 shmap 作为与 xmap 分开的原语来实现,因为限制其支持的功能集使得更容易专注于核心功能。例如,shmap 不允许未映射的中间步骤,这使得不必担心命名轴和自动微分之间的交互。此外,不必考虑所有功能对之间的交互,使得更容易添加超出 xmap 当前实现的能力,例如对急切模式的支持。

shmapxmap 共享了降低代码的显著部分。我们未来可以考虑将两者合并,或者甚至只专注于 shmap,这取决于使用情况的发展。