Pallas中TPU的分布式计算#

在本教程中,我们将介绍在TPU上Pallas的分布式计算基础知识。我们将学习TPU拓扑结构、使用远程DMA原语进行通信,以及如何使用shard_map从JAX调用分布式内核。我们还将讨论一些更高级的内核编写技术,如双缓冲、双向带宽优化和嵌套流水线。作为教育示例,我们将学习如何从JAX实现各种集合原语,例如lax.ppermutelax.all_gatherlax.psumlax.psum_scatter

一些推荐的阅读材料:

import jax
from jax import lax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental import pallas as pl
from jax.experimental import shard_map
from jax.experimental.pallas import tpu as pltpu

P = jax.sharding.PartitionSpec

num_devices = jax.local_device_count()
assert num_devices > 1, "Please run this notebook with more than one device."
assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
Running with 4 TPU v5 lite devices.

TPU 拓扑结构#

TPU 通常以多个设备的集群形式部署,这些设备通过高速的芯片间互连 (ICI) 进行通信,速度远快于典型的网络连接。例如,TPU v5p 的规格表中指出,每个芯片的 ICI 带宽为 4.8Tb/s(作为参考,TPU v5p 还具有 21Tb/s 的 本地 HBM 带宽)。ICI 使我们能够实现快速高效的分布式内核,这些内核在集群内部需要高带宽通信,并使用数据中心网络对带宽需求较低的操作进行并行化,例如在批量维度上的数据并行性。

TPU 集群通常以 ND 超环拓扑结构排列。以下图形给出了不同规模配置的几个例子。

tpu_topologies

将超环转化为图形,可以如下可视化。每条边(橙色或黑色)是两个设备之间的双向连接。您会经常听到与设备拓扑讨论相关的环——超环的一个关键特征是,当沿着集群的一个轴进行切片时,例如节点 [(0,1), (1, 1), (2, 1), (3, 1)][(0, 1), (1, 1)],我们会得到一个设备环。这是我们可以用来简化集群内部通信模式的一个特征。

tpu_torus

远程直接内存访问(RDMA)模型#

TPU 通过一种仅推送的模型进行通信,称为远程直接内存访问(RDMA)。TPU 被允许发出从本地缓冲区推送到同一 Pod 中另一个设备上的任何缓冲区的复制指令,该操作与主程序线程异步执行。然而,TPU 只能读取存储在本地的数据。这与更传统的多核编程形成对比,在传统模式下,可以对共享内存中的值进行读取和写入。

异步远程复制操作#

pltpu.make_async_remote_copy 函数用于创建一个远程 DMA 描述符对象,参数化 “发送” 操作和 “接收” 操作。它的函数签名如下:

 def make_async_remote_copy(
     src_ref: Ref,
     dst_ref: Ref,
     send_sem: Ref[SemaphoreType],
     recv_sem: Ref[SemaphoreType],
     device_id: int | tuple[int, ...],
     device_id_type: DeviceIdType
 ) -> AsyncCopyDescriptor:
  • src_ref 是包含希望发送到另一个设备 dst_ref 的本地 Ref(在任何内存空间中)。

  • dst_ref 是目标设备上的远程 Ref(在任何内存空间中),数据将复制到此处。

  • send_sem 是一个 DMA 信号量,用于阻止在所有数据从 src_ref 发送完成之前。

  • recv_sem 是一个 DMA 信号量,用于阻止在 dst_ref 收到预期字节数之前。DMA 的发送者将写入接收者的 recv_sem

  • device_id 是要发送数据的目标设备的设备 ID。

  • device_id_type 指定 device_id 的格式,可以是 LOGICAL 格式(整数设备 ID)或 MESH 格式(逻辑设备网格中的 ND-tuple 索引)。默认模式是 MESH。

make_async_remote_copy 返回一个描述符对象,您可以使用 .start() 方法启动 DMA,并使用 .wait_send() 阻塞在 send_sem 上,以及使用 .wait_recv() 阻塞在 recv_sem 上(或者使用 .wait() 在两者上阻塞)。如果某个设备仅预期发送数据,则只需调用 .start().wait_send(),如果设备仅接收数据,则只需调用 .wait_recv()。如果使用 SPMD 模式,其中所有设备执行 DMA,则每个设备通常会调用 .start().wait()

dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
dma_descriptor.start() # 启动 DMA(非阻塞)。
# ... 执行其他工作
dma_descriptor.wait_send() # 在所有数据发送完成之前阻塞。
dma_descriptor.wait_recv() # 在所有数据接收完成之前阻塞。

作为示例,我们来可视化一个 DMA,我们考虑 4 个设备(索引为 0、1、2、3)。我们考虑一种方案,其中设备 0 复制到设备 1,设备 2 和 3 相互复制。在实践中,我们可以通过使用 @pl.when 根据设备 ID 创建这样的不对称通信模式。

(1) 每个设备创建 DMA 描述符。设备 0、2 和 3 调用 .start()src_ref 启动 DMA。设备 1 跳过 .start(),不执行任何操作,例如,通过使用 pl.when

rdma_start

(2) 由于 .start() 是非阻塞的,每个设备在 DMA 进行时可以执行其他计算。设备 0、2 和 3 调用 .wait_send() 等待 send_sem,该方法阻止在所有数据发送完成之前。

rdma_send

(3) 最后,设备 1、2 和 3 将调用 .wait_recv() 等待 recv_sem,直到所有数据到达 dst_ref

rdma_recv

上述通信模式可以写成如下形式:

def example_kernel(input_ref, output_ref, send_sem, recv_sem):
    device_id = lax.axis_index('x')
    copy_0_to_1 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=1,
    )
    copy_2_to_3 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=3,
    )
    copy_3_to_2 = pltpu.make_async_remote_copy(
        src_ref=input_ref,
        dst_ref=output_ref,
        send_sem=send_sem,
        recv_sem=recv_sem,
        device_id=2,
    )
    @pl.when(device_id == 0)
    def _():
      copy_0_to_1.start()
      copy_0_to_1.wait_send()
    @pl.when(device_id == 1)
    def _():
      copy_0_to_1.wait_recv()
    @pl.when(device_id == 2)
    def _():
      copy_2_to_3.start()
      copy_2_to_3.wait_send()
      copy_3_to_2.wait_recv()
    @pl.when(device_id == 3)
    def _():
      copy_3_to_2.start()
      copy_3_to_2.wait_send()
      copy_2_to_3.wait_recv()

DMA 信号量#

send_semrecv_sem 是一种特殊类型的信号量的实例,仅用于 DMA。在指定输入规范时,必须使用 tpu.SemaphoreType.DMA 类型来分配它们给 pallas_call

从内部来看,DMA 信号量可以被视为整数值的进度跟踪器。在 DMA 启动时,本地设备将开始异步递增 send_sem 的值以及接收器的 recv_sem。等待信号量将阻塞,直到信号量的值达到发送/接收的总字节数;当达到该值时,等待线程被释放,信号量的值将减少相同的数量。这意味着要么所有数据已经发送(针对 send_sem),要么所有数据已经接收(针对 dst_sem)。信号量的值可以通过 pl.semaphore_read 读取,但请注意,该值的底层语义可能会在硬件代之间变化(例如,该值可能不准确反映发送的字节数,尽管这是在推理信号量行为时有用的思维模型)。

路由#

发送者被允许将数据发送到同一 Pod 内的任何接收者,即使它们没有共享的直接连接(TPU v5e 除外,设备只能路由到与自身的指数为 2 的偏移量)。TPU 具有一种内部路由机制,可以将数据传递给到达目标的下一台设备。然而,以这种方式进行通信并不推荐,因为作为内核编写者,您无法控制网络争用。我们将在本教程中讨论的示例通过仅将数据传递给相邻设备来最小化低效的通信。

失败模式#

如果错误使用远程 DMA,您可能会遇到一些失败模式,这些模式可能难以调试。错误的 DMA 用法的一般症状是崩溃、挂起或静默数据损坏:

  • 如果信号量以无效的非零值退出程序,Pallas 将崩溃并退出程序。

  • 如果等待信号量但接收到的字节数不足(即没有发送者,或者发送的数据小于接收设备上 dst_ref 的大小),程序可能会无限期挂起,等待从未发送的字节。在这种情况下,程序需要重新启动。

  • 如果遇到竞争条件,可能会发生静默数据损坏,如果同时发生两个写操作或同时读取和写入。

上述问题的一些常见原因包括:

  • 如果某个设备调用 .wait_recv() 但没有其他设备发送数据给它,内核可能会挂起。

  • 如果发送给某个设备的字节数超过其预期接收的字节数,它可能也会因非零信号量状态而崩溃。如果发送的字节数不足,它可能会无限期挂起。

  • 如果启动 DMA 但未等待信号量,程序可能会因非零信号量状态而崩溃。

  • 如果两个设备同时复制到同一个目的地,您可能会遇到由于竞争条件而导致的非确定性结果,或者由于非零信号量状态而崩溃。

示例:右置换 (lax.ppermute)#

让我们深入一个非常基本的示例。我们将实现一个执行右置换的内核,其中每个设备将其数据切片发送到它右边的邻居。

假设我们有一个包含512个元素的数组,我们将其分割成4个设备,每个设备的切片大小为128。每个设备将把其切片传递给下一个设备,输出将由相同的数据组成,但切片会旋转1个位置。这与将置换设置为(n, (n+1) % 4)lax.ppermute操作是相同的。

为了以分布式模式调用内核,我们将pallas_call包装在shard_map转换中。从那里,我们可以以与编写普通单设备Pallas内核相同的方式编写内核,只是现在我们可以访问远程DMA指令。JAX汇聚原语,如lax.axis_index,可用于获取可以用于计算要复制到哪些目标设备的device_id,方法是引用传递给shard_map的相同命名轴名称。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# 创建一个输入数组,该数组将最后一个维度分片
# 所有设备。
input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=input_ref,
      dst_ref=output_ref,
      send_sem=send_sem,
      recv_sem=recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()


out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    # TPUMemorySpace.ANY 通常会将张量放置在HBM中。
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    scratch_shapes=(
        # 我们在暂存内存中分配DMA信号量。
        [pltpu.SemaphoreType.DMA] * 2
    ),
)
right_permute = pl.pallas_call(
    right_permute_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
)
# 将内核包裹在shard_map中进行调用。
pallas_result = jax.jit(
    shard_map.shard_map(
        right_permute,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)

# 将Pallas结果与XLA shard_map结果进行对比。
perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))

xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.ppermute(x, 'x', perm),
        mesh=mesh, in_specs=partition, out_specs=partition)
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas Result = ', pallas_result[0, ::128])
print('lax.ppermute Result = ', xla_result[0, ::128])
print(
    'Difference |Pallas - lax.ppermute| = ',
    jnp.mean(jnp.abs(pallas_result - xla_result)),
)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
lax.ppermute Result =  [0.775211   0.9858954  0.11763906 0.9955574 ]
Difference |Pallas - lax.ppermute| =  0.0

示例:全收集(lax.all_gather#

在下一个示例中,我们将实现全收集集体操作,在JAX中对应于lax.all_gather。与上面提到的仅涉及一对源和目标邻居的右置换示例不同,全收集操作需要所有设备之间的通信,因此我们必须考虑数据如何在它们之间路由。我们实施这一点的具体方式由设备拓扑决定,我们假设这是一个环。

环形通信模式#

我们将编写我们的内核,假设有一个环形拓扑。环形结构非常适合TPU,因为沿着任意维度切片一个环面会生成一个环。当编写集体操作时,我们通常只需考虑环面的一维切片,因为环面的不同维度保留给不同类型的并行性(例如数据与模型)。

我们将采用的策略是编写一个循环内核,在每次迭代中,一个设备从其左邻居接收一片分片数组,并将先前接收到的片段复制到其右邻居。经过num_devices次迭代,每个设备将在其本地HBM中拥有整个数组的副本。

all_gather

我们可以重新利用Pallas的grid参数来实现这个循环。与我们在之前的教程中对数组的块进行迭代不同,我们将网格设置为(num_devices,),以指示我们希望对设备的数量进行循环,并使用pl.program_id在Pallas内核内部获取循环迭代。以下代码片段演示了如何实现这一点:

partition = P('x', None)
devices = mesh_utils.create_device_mesh((num_devices, 1))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# 创建一个输入数组,将第一个维度分片到
# 所有设备。
input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
input_arr = jax.device_put(input_arr, sharding)


def all_gather_kernel(input_ref,
                      output_ref,
                      local_copy_sem,
                      send_sem,
                      recv_sems):
  outer_step = pl.program_id(0)
  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  copy_slot = my_id - outer_step
  copy_slot = lax.rem(copy_slot + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    local_copy_op = pltpu.make_async_copy(
      src_ref=input_ref,
      dst_ref=output_ref.at[my_id],
      sem=local_copy_sem,
    )
    local_copy_op.start()
    local_copy_op.wait()

  # 抄送我们的右邻。
  # 请注意,我们还将从左边的邻居接收数据,
  # 但要在 `copy_slot-1` 而不是 `copy_slot`!这是利用了
  # 这些指数不需要在远程DMA之间对称。
  remote_copy_op = pltpu.make_async_remote_copy(
      src_ref=output_ref.at[copy_slot],
      dst_ref=output_ref.at[copy_slot],
      send_sem=send_sem,
      recv_sem=recv_sems.at[outer_step],
      device_id=(right_neighbor, 0),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  remote_copy_op.wait()

out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
grid_spec = pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                # TPUMemorySpace.ANY 通常会将张量放置在HBM中。
                pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
            scratch_shapes=(
              # DMA信号量分配在暂存内存中。
              # 我们为本地HBM-VMEM拷贝分配了一个信号量。
              # 并为远程发送信号量分配一个。
              [pltpu.SemaphoreType.DMA] * 2
              # 我们另外为每个设备分配一个接收信号量。
              # 这是为了避免出现我们拥有多个
              # 飞行中的DMA,因为我们不希望共享接收。
              # DMA之间的信号量。
              + [pltpu.SemaphoreType.DMA((num_devices-1,))]

            ),
            grid=(num_devices-1,)
        )

all_gather = pl.pallas_call(
      all_gather_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
  )

# 将内核包裹在shard_map中进行调用。
pallas_result = jax.jit(
      shard_map.shard_map(
          all_gather,
          mesh=mesh,
          in_specs=partition,
          out_specs=partition,
          check_rep=False
      )
)(input_arr)

# 将Pallas结果与XLA shard_map结果进行对比。
xla_result = jax.jit(
    shard_map.shard_map(
        lambda x: lax.all_gather(x, 'x'),
        mesh=mesh, in_specs=partition, out_specs=partition
    )
)(input_arr)

print('Input: ', input_arr.shape, input_arr[::8, 0])
print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
print('Difference |Pallas - lax.all_gather| = ',
      jnp.mean(jnp.abs(pallas_result - xla_result)))
Input:  (32, 128) [0.9858954  0.54248166 0.9547038  0.954962  ]
Pallas Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
lax.all_gather Result:  (16, 8, 128) [0.9858954  0.54248166 0.9547038  0.954962   0.9858954  0.54248166
 0.9547038  0.954962   0.9858954  0.54248166 0.9547038  0.954962
 0.9858954  0.54248166 0.9547038  0.954962  ]
Difference |Pallas - lax.all_gather| =  0.0

这里值得一提的是使用多个接收信号量。因为我们只在接收设备上进行阻塞,所以发件方仍然可以在接收方完成处理第一个DMA之前发送多个DMA(参见下一个章节和讨论竞争条件的reduce-sum示例)。在这种情况下,我们可能会遇到同一个信号量被用于同时发生的多个DMA的情况。为避免这种情况,我们分配了 num_devices-1 个信号量,从而消除重复使用的风险。虽然这种竞争条件在如此小的内核中不太可能发生,但在较大的内核中,设备失去同步的可能性更大,并可能导致静默失败。

高级技术#

现在我们已经了解了如何使用远程DMA操作编写几个基本的内核,我们将探讨更高级的同步技术和编写高效内核的方法。

同步:常规信号量和屏障信号量#

我们在基础教程中实现的示例不需要特殊的同步处理,因为所有必要的通信都写入了不相交的缓冲区。然而,其他操作可能需要更复杂的通信模式,这需要额外的同步原语以避免竞争条件。Pallas 提供了两种额外的原语来帮助解决这个问题:常规信号量和屏障信号量。

常规信号量#

常规信号量是用于跨多个设备进行同步的标准工具。信号量从根本上来说是计数器——它们可以被任何设备递增,此后设备可以阻塞,直到信号量的值达到特定值(然后再将值递减)。

可以对常规信号量执行的三种主要操作是信号、等待和读取:

def semaphore_signal(
    sem: Ref[SemaphoreType],
    inc: int,
    device_id: int | tuple[int, ...],
    device_id_type: DeviceIdType
) -> None:
  ... # 将目标设备 `device_id` 上的信号量 `sem` 递增 `inc`。
  
def semaphore_wait(
    semaphore: Ref[SemaphoreType],
    value: int,
) -> None:
  ... # 阻塞直到本地分配的 `sem` 达到 `value`,然后递减 `value` 并继续。
    
def semaphore_read(
    sem: Ref[SemaphoreType],
) -> jax.Array:
  ...  # 返回 `sem` 的当前值作为 `int32[]`。

为了使用常规信号量,可以通过指定 pltpu.SemaphoreType.REGULAR 而不是 pltpu.SemaphoreType.DMA 以与 DMA 信号量相同的方式进行分配。

信号量在 Pallas 程序结束时必须为零才能成功完成。可能发生的两种错误情况是:

  • 如果信号量超量信号,则程序将在非零(>0)信号量的情况下结束。在这种情况下,程序将在完成时崩溃。这对于调试非常有用,因为非零信号量通常意味着程序内部存在错误。

  • 如果信号量超量等待,程序将在阻塞的 semaphore_wait 调用上挂起,同时等待信号量递增。在这种情况下,需要重新启动设备或程序。

屏障信号量#

屏障信号量是全局分配的信号量,用于在整个程序中同步设备,并确保所有设备都已进入 Pallas 内核。

如果在较大的 XLA 程序的上下文中执行 Pallas 内核,我们需要确保所有进行通信的设备都已进入该内核。然而,DMA 和常规信号量都是局部作用域的——它们仅被其他已经进入内核的设备理解。屏障信号量作为全局可理解的信号量,可以用于不同 XLA 程序中设备的同步。

默认情况下,如果不指定屏障信号量,Pallas 将在程序开始时自动插入一个屏障信号量。然而,编写自己的屏障信号量可能更高效。屏障信号量与常规信号量类似,都是计数器,可以通过 semaphore_signal 进行递增,并可以通过 semaphore_wait 进行递减。它们通过在内核中调用 get_barrier_semaphore() 创建。通常,我们在内核的开始处使用屏障来与所有我们进行通信的设备同步。

from jax.experimental.pallas import tpu as pltpu

def example_kernel(...):
  # 在内核开始时使用屏障信号量。
  # is_start_of_kernel = ...
  # right_neighbor = ...
  # ...
  @pl.when(is_start_of_kernel)
  def _():
    barrier_sem = pltpu.get_barrier_semaphore()
    # 递增你的右邻居的信号量。
    pltpu.semaphore_signal(
          barrier_sem,
          device_id=right_neighbor,
          device_id_type=pltpu.DeviceIdType.LOGICAL,
    )
    # 等待你的左邻居递增你的信号量
    pltpu.semaphore_wait(barrier_sem, 1)
  # ...

使用屏障信号量时,必须将 collective_id 编译器参数传递给 pallas_call 以指定使用哪个屏障信号量。一个 TPU 只有少量固定数量的屏障信号量(通常在 20-30 个数量级),因此应谨慎使用。为了确保正确性,只有共享相同通信模式的内核应使用相同的 collective_id。例如,如果两个内核仅与同一网格轴上的邻居进行同步,则可以共享相同的 collective_id。然而,如果两个内核在不同的轴上进行同步,则必须使用不同的 collective_id。如果不这样做,可能会导致难以调试的竞争条件。

kernel = pl.pallas_call(
      example_kernel,
      ...,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

双缓冲#

为了避免从一个本地 Ref 读取数据时,另一个设备也在进行写入,从而产生竞争条件,有一种有用的技术是“双缓冲”策略,我们为每个目标值分配两个 Ref。在每次迭代中,一个 Ref 将被指定为“工作”槽,而另一个将被指定为“接收”槽。设备可以自由地使用工作槽进行计算,但只会将数据复制到其邻居的接收槽中。工作槽和接收槽在每次迭代中交替,因此一旦复制完成,旧的接收槽变成新的工作槽,反之亦然。通过适当使用此方案,数据永远不会从同一个缓冲区中读取和写入。

以下代码框架展示了如何使用双缓冲。我们在变量 iteration 中保留一个运行的迭代计数器,working_slotreceiving_slot 每次迭代在 0 和 1 之间交替。dst_ref 被分配为双缓冲,大小为 [2, ...]。在每次迭代中,我们使用 dst_ref.at[working_slot, ...] 从工作槽中读取数据,并使用该值进行计算。同时,我们复制到邻居的 dst_ref.at[receiving_slot] 以避免覆盖他们的 working_slot 值。通过这种方式构建通信,可以在最小化竞争条件风险的同时,比特尔重叠远程 DMA 的通信延迟与本地计算。

def kernel(...):
  # ...
  iteration = pl.program_id(0)
  working_slot = lax.rem(iteration, 2)
  receiving_slot = 1 - working_slot
  # ...

  local_copy_op = pltpu.make_async_copy(
    src_ref=dst_ref.at[working_slot, ...],
    dst_ref=local_scratch_ref,
    sem=local_copy_sem,
  )
  local_copy_op.start()
  remote_copy_op = pltpu.make_async_remote_copy(
    src_ref=src_ref,
    dst_ref=dst_ref.at[receiving_slot, ...],
    send_sem=send_sem,
    recv_sem=recv_sem,
    device_id=target_device,
    device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy_op.start()
  
  local_copy_op.wait()
  # ... 在等待 async_copy_op 完成时对 local_scratch 进行工作。
  remote_copy_op.wait()

就同步而言,双缓冲结构在所有设备执行相同迭代时有效。如果发送方成功比接收方提前一轮迭代,则其 working_slotreceiving_slot 索引将与接收方翻转,这意味着它可能在接收方从中读取的同时向 working_slot 写入数据。为了避免这种情况,可能需要使用信号量来同步发送方与接收方,或添加额外的缓冲槽(“三重”、“四重”或 N 缓冲)以允许在占用更多内存的情况下进行额外的提前运行。在我们之前的 all_gather 示例中,注意到内核中包含具有 N 个槽的接收缓冲,完全避免了竞争条件。在我们的下一个内核中,我们将通过一个使用显式同步的双缓冲示例。

示例:全规约求和(lax.psum#

我们将使用双缓冲和信号量进行同步,现在实现一个全规约求和内核。对于熟悉JAX中的集体操作的人来说,等效的操作是lax.psum。全规约是一个标准的集体操作,其目标是在数组的一个轴上进行规约,但该数组在多个设备之间进行分片。

reduce_sum_1

在上面的例子中,数组[5, 2, 1, 3]在4个设备之间进行了分片。全规约求和操作将对所有值求和,并将结果在每个设备上复制,最终导致结果为[11, 11, 11, 11],在所有4个设备之间进行了分片。

全规约的简单实现将是将所有所需值收集到每个设备上,然后进行规约。然而,我们可以通过将通信与计算交错来提高该实现的性能。交错的单向全规约可以如下可视化。在每次迭代中,我们从左邻居接收一个输入值,并同时将输入传递给下一个邻居,同时用我们的本地累加器对其进行增量处理。在N-1次迭代后,每个设备的内存中将有一个完整和的副本。

reduce_sum_2

整合所有内容#

以下内核演示了如何将这些原理结合成一个功能性内核。

outer_step==0时,序言首先与两个邻居发起屏障,以确保它们也已进入内核。同时,它还处理所有Ref的初始化,并处理第一次向右邻居的“工作”槽的远程复制。

主程序假设值已从先前的迭代或序言中复制到我们的本地工作槽中。一个复杂的因素是我们的目标缓冲区存在于HBM中,但在执行算术操作之前,我们需要将值加载到VMEM中。因此,我们同时将工作槽的值复制到我们的VMEM(receive_scratch)中,并将值传递给右邻居的接收槽。一旦值被复制到我们的VMEM中,我们就可以将其累积到我们的结果中(包含在o_ref中)。

如果一个设备的循环进度超前于它的右邻居,就可能发生微妙的竞争条件。在这种情况下,它可能会在接收方读取时同时复制到接收者的working_slot中。为避免这种情况,每个设备将在复制到右邻居的dst_ref之前,会在REGULAR信号量上阻塞,直到它发出信号表示已经完成从其working_slot的读取。对于像这个例子这样的小内核,这种竞争条件很少被触发,但如果使用pltpu.delay指令以人为地挂起设备,可能会明确触发。

请注意,这不是一个最优或完全通用的内核,因为块大小必须完全适合VMEM,并且我们可以更好地交错通信和累积。在后面的部分中,我们将讨论这些优化。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
input_arr = jax.device_put(input_arr, sharding)


def all_reduce_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    copy_sem,
    remote_recv_sem,
    remote_send_sem,
    capacity_sem,
    receive_scratch,
):
  outer_step = pl.program_id(0)
  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot

  my_id = lax.axis_index('x')
  right_neighbor = lax.rem(my_id + 1, num_devices)
  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)

  @pl.when(outer_step == 0)
  def _():
    # 一开始就与两个邻居为敌,因为我们将会
    # 与两者进行沟通。
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # 初始化o_ref、acc_scratch和hbm_scratch。
    o_ref[...] = jnp.zeros_like(o_ref)
    receive_scratch[...] = jnp.zeros_like(receive_scratch)
    initial_copy = pltpu.make_async_remote_copy(
        src_ref=x_ref,
        dst_ref=hbm_scratch.at[working_slot],
        send_sem=remote_send_sem,
        recv_sem=remote_recv_sem,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    initial_copy.start()
    initial_copy.wait()

  # 向左邻示意,我方已做好接收准备。
  # 如果没有这个信号,我们的左邻居可能会领先超过一次迭代,
  # 这意味着它可以写入我们的工作槽。
  pltpu.semaphore_signal(
      capacity_sem,
      inc=1,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # 将左邻发送给我们的部分结果复制到VMEM中
  # 计算。
  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=receive_scratch,
      sem=copy_sem,
  )
  local_copy.start()

  # 阻塞直到我们的右邻准备好接收。
  pltpu.semaphore_wait(capacity_sem, 1)
  # 将值传递给右边的邻居。
  remote_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot],
      dst_ref=hbm_scratch.at[receiving_slot],
      send_sem=remote_send_sem,
      recv_sem=remote_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  remote_copy.start()
  # 完成本地副本并累积,同时远程副本正在进行中。
  local_copy.wait()
  o_ref[...] += receive_scratch[...]
  # 阻塞直到远程复制完成。
  remote_copy.wait()


out_shape = (
    jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # 我们将双缓冲区分配为Pallas输出,以便其
    # 居住在HBM。
    jax.ShapeDtypeStruct((2, 8, 128), jnp.float32),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        # 我们的输入存在于VMEM中。
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        # 我们的输出存在于虚拟内存中。
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        # 我们的双缓冲区位于HBM中。
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices,),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 3
        + [pltpu.SemaphoreType.REGULAR]  # 容量半
        + [pltpu.VMEM((8, 128), jnp.float32)]  # 接收划痕
    ),
)

kernel = pl.pallas_call(
    all_reduce_kernel,
    out_shape=out_shape,
    grid_spec=grid_spec,
    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
)

pallas_result = jax.jit(
    shard_map.shard_map(
        kernel,
        mesh=mesh,
        in_specs=partition,
        out_specs=partition,
        check_rep=False,
    )
)(input_arr)
pallas_result = jax.block_until_ready(pallas_result)[0]


def lax_sum(x):
  return lax.psum(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
    )
)(input_arr)

print('Input = ', input_arr[0, ::128])
print('Pallas result = ', pallas_result[0, ::128])
print('lax.psum result = ', xla_result[0, ::128])
difference = jnp.mean(jnp.abs(pallas_result - xla_result))
print('Difference |Pallas - lax.psum| = ', difference)
Input =  [0.9858954  0.11763906 0.9955574  0.775211  ]
Pallas result =  [2.8743029 2.8743029 2.8743029 2.8743029]
lax.psum result =  [2.8743029 2.8743029 2.8743029 2.8743029]
Difference |Pallas - lax.psum| =  1.4959369e-08

预运行和竞态条件#

作为一个一般性的经验法则,为了最大化性能,我们希望允许一个设备在不进行同步的情况下,尽可能地超前于其他设备,而不牺牲程序的正确性。虽然我们可以在每次迭代的开始强制所有设备之间进行屏障同步,但这会导致程序在每个循环中被阻碍到最慢的设备。通过放宽同步并允许适度的超前,我们可以更好地适应迭代与设备之间的延迟差异,因为一个在某次迭代中较慢的设备可以在下一次迭代中赶上。

在我们之前写的全归约内核中,我们允许设备超前,但比其邻居少一个迭代(然而,非邻接设备之间可以相差超过1个迭代)。要理解为什么需要信号量同步,考虑一下当一个设备(比如设备2)挂起并落后于其他设备的情况。RDMA没有“握手”——只有接收者在等待数据到达时被阻塞。因此,每个设备最多可以超前一个迭代,然后就会因等待下一个RDMA到达而被阻塞。如果我们有N个设备,这意味着最后一个设备最多可以比第一个设备超前N个迭代。

race_condition

如果不在另一个方向上添加同步(强制发送者阻塞),设备1有可能会比设备2超前多达N个迭代(N = num_devices),在此过程中发送多个写操作并覆盖已存在的值。为了解决我们之前编写的all_reduce内核中的这个问题,我们实现了一个“握手”协议,其中接收者向发送者发出信号表示它已经准备好接收,然后发送者才开始发出下一个RDMA。

双向通信#

在我们之前的内核中,我们在一个环上以从左到右的单向方式进行通信。然而,由于 ICI 连接是双向的,我们实际上通过没有从右到左发送值而浪费了总带宽的一半。在下一个内核中,我们将演示一个在两个方向上进行通信的示例,以最大化 ICI 带宽。

示例:双向归约散布(lax.psum_scatter#

归约散布操作是全归约和散布的组合。或者,换句话说,全归约是归约散布和全收集的组合。

下图描述了此操作的语义。我们假设每个设备开始时都有一组部分和(用字母+数字表示,例如A0)。目标是在一个轴上进行归约(数字),同时在另一个轴上进行分片(字母)。

reduce_scatter_1

为了实现双向通信策略,我们将每个输入块切成两半,并为每一半指定一个方向。每个块的上半部分将从右向左传递,而下半部分将从左向右传递。与我们之前的全归约和全收集内核的通信模式的第二个偏差是,我们还将传递累加器或部分和,并保持输入在每个设备本地。这与之前的示例形成对比,在这些示例中,我们传递输入但保持累加器本地。传递累加器对于这个问题更为自然,因为与全归约相比,输入中的大部分数据并不是最终会存储在设备上的输出的一部分。(例如,上图中的B0C0D0不会在操作结束时存储在持有A的设备上)。

下图说明了这种通信模式,其中彩色框表示累加器(而不是输入!)。最初,累加器仅仅是输入中包含的值。在算法的每次迭代中,我们将从每个方向的邻居那里接收一个部分和。然后,我们计算正确的输入片段以累加到部分缓冲区中,然后将新的部分和传递给下一个邻居。经过N次迭代后,累加器将通过每个设备,这意味着它最终将拥有完整的和。

reduce_scatter_2

在内核的构建方面,我们在Pallas网格中引入一个额外的phase维度,用于表示我们当前正在计算哪个累加器(左或右)。我们让phase=0表示向左移动的累加器,而phase=1表示向右移动的累加器。然后我们对这两个相位进行流水线处理,在计算一个相位的结果时,我们在相反方向传输之前计算的值,为下一个相位做好准备。例如,当我们处于phase=0(左)时,我们首先开始进行DMA,将我们在前一次迭代中计算的结果传输到右邻居(右DMA)。然后,我们累加到左缓冲区并将结果保存到HBM。然后,我们等待右DMA完成,以便它为phase=1(右)做好准备。

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# 我们需要一个大小为 (16, 128) 的块,以确保半切片至少
# 大小为 (8, 128),这正是 VREG 的大小。这使得分块操作更加简便。
# 对于编译器。
block_size = (16, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(block_size[0] * num_devices, block_size[1] * num_devices),
)
input_arr = jax.device_put(input_arr, sharding)

LEFT = 0
RIGHT = 1


def mod(x, n):
  return lax.rem(x + n, n)


def signal(left_or_right, semaphore):
  my_id = lax.axis_index('x')
  if left_or_right == LEFT:
    neighbor = mod(my_id - 1, num_devices)
  else:
    neighbor = mod(my_id + 1, num_devices)
  pltpu.semaphore_signal(
      semaphore,
      inc=1,
      device_id=(0, neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    local_copy_sem,
    left_recv_sem,
    left_send_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
    accum_scratch,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  # 可以使用 pl.ds(start, size) 来指定切片
  left_copy_slice = pl.ds(0, block_size[0] // 2)
  right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
  current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      # 注意:由于我们正在进行复制操作,右侧副本相对于插槽是翻转的。
      # 到下一个外层迭代步骤。
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- 序章 ---
  @pl.when(is_start)
  def _():
    # 一开始就与两个邻居形成屏障,因为我们将会
    # 与两者进行沟通。
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    # 使用初始副本初始化 o_ref、acc_scratch 和 hbm_scratch。
    o_ref[...] = jnp.zeros_like(o_ref[...])
    accum_scratch[...] = jnp.zeros_like(accum_scratch[...])

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # 我们告诉左边的邻居,允许它向右边发送消息。
    # (反之亦然,对于右邻)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  # --- 正文 ---
  # 在我们内核体的起始部分,我们启动了一个DMA,用于复制
  # 将我们在上一阶段计算得到的结果发送给我们的邻居。
  # 这使得我们能够重叠发送上一阶段的通信
  # 与当前阶段的计算。
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # 我们在此阻塞,直到右邻告知我们可以发送为止。
      # 右边。
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # 我们在此阻塞,直到左邻告知我们可以发送为止。
      # 左边。
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  local_copy = pltpu.make_async_copy(
      src_ref=hbm_scratch.at[working_slot, current_phase_slice],
      dst_ref=accum_scratch,
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]

    @pl.when(phase == RIGHT)
    def _():
      accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]

  local_copy = pltpu.make_async_copy(
      src_ref=accum_scratch,
      dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
      sem=local_copy_sem,
  )
  local_copy.start()
  local_copy.wait()

  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  # 在我们的内核体末尾,我们等待前一阶段的DMA完成。
  # 以确保结果为下一阶段做好准备。
  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # --- 尾声 ---
  # 存储最后一次迭代的结果。
  @pl.when(last_iteration)
  def _():
    # 清理信号量,使其以值0退出。
    @pl.when(phase == LEFT)
    def _():
      o_ref[left_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      o_ref[right_copy_slice, ...] = accum_scratch[...]
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32),  # 输出
    # 形状:[工作/接收,块[0],块[1]]
    jax.ShapeDtypeStruct(
        (2, block_size[0], block_size[1]), jnp.float32
    ),  # hbm_暂存区
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # 容量信号量
        + [
            pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
        ]  # 累加暂存
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# 将我们的结果与XLA进行比较。
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, block_size[0], block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (64, 512) [0.78051674 0.3524047  0.59993696 0.9714314  0.24692321 0.01347649
 0.01857424 0.24841607 0.86097646 0.8261659  0.9753758  0.6902338
 0.4431417  0.963323   0.3158517  0.535548  ]
Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869  1.4194957 1.4163033 1.2401303
 1.1892898 2.6545286 2.221559  2.7995253 2.08431   2.2509837 3.0726733
 2.4662397 1.9542246]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

嵌套的远程和本地 DMA 流水线#

我们之前编写的所有归约和归约散发内核的一个限制是,我们通过远程 DMA 复制的块必须足够小,以适应我们用于积累的工作 VMEM。对于某些内核,使用更大的块大小可能更有利于更好地利用 TPU。例如,矩阵乘法需要大约 \(O(N^3)\) 的计算操作,但仅需要 \(O(N^2)\) 的内存传输。因此,我们希望在设备之间转移的每个工作块都足够大,以使操作变为计算限制,并且我们可以利用流水线隐藏通信成本。作为参考,TPU 的 VMEM(第 v4/v5 代)典型大小在 10-100MB 之间,而 HBM 的范围为 10-100GB。

为了解决这个问题,我们需要能够编写一个“内部内核”,它处理“外部内核”内的本地 HBM-VMEM 流水线,该外部内核处理设备之间更大的 HBM-HBM 传输的流水线。Pallas 提供了一个 API,用于使用 emit_pipeline 函数构建嵌套流水线。emit_pipeline 的基本调用签名遵循标准的 pallas_call,通过指定输入和输出的 gridBlockSpec:

def emit_pipeline(
    kernel: Callable,
    grid: tuple[int],
    in_specs: PyTree[BlockSpec] = None,
    out_specs: PyTree[BlockSpec] = None,
    should_accumulate_out: bool = False,
    dimension_semantics: tuple[GridDimensionSemantics] = None,
) -> Callable:
  ... # 返回一个给定内部内核和 BlockSpecs 的自定义流水线。

实际上,可以将 pallas_call 本身视为仅仅是 emit_pipeline 的一个包装器。由于我们的外部内核仅涉及远程 HBM-HBM 传输,因此我们未使用 pallas_call 为 HBM-VMEM 传输提供的任何内置流水线。以下代码框架演示了使用此模式的典型程序结构:


def outer_kernel(...):
  # ... 执行负责流水线远程 HBM-HBM 传输的工作(外部内核)

  def inner_kernel(...):
    # ... 执行工作(内部内核)
  pltpu.emit_pipeline(
          inner_kernel,
          grid=inner_grid,
          in_specs=...,
          out_specs=...,
  )(inner_kernel_args)
  # ... 执行更多工作(外部内核)

pl.pallas_call(
  outer_kernel,
  grid=outer_grid,
  in_specs=...
  out_specs=...
  scratch=inner_kernel_allocs
)

示例:使用大HBM块进行Reduce-Scatter#

在下一个示例中,我们将修改之前的reduce-scatter示例,以利用嵌套的内部管道。请注意,reduce_scatter的通信和计算成本都随着输入大小的增加而线性增加,因此我们不一定期望在较大的块大小下,该操作变得计算密集。本示例纯粹是为了演示如何使用管道发射器。

我们将增加外部内核的块大小,使其不适合放入VMEM,并在HBM中分配所有输入和输出(memory_space=TPUMemorySpace.Any)。与之前的内核相比,唯一的主要变化是在进行累积的内核主体。我们不再手动将数据从HBM复制到VMEM进行累积,然后再复制回HBM,而是使用emit_pipeline来处理内存传输。累积在内部内核中以更小、更适合VMEM的块大小进行。

在我们之前的内核中,有以下内核主体,用于从HBM复制数据到VMEM累积器,进行增量,然后将结果再次复制到HBM:

local_copy = pltpu.make_async_copy(
    src_ref=hbm_scratch.at[working_slot, current_phase_slice],
    dst_ref=accum_scratch,
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
  @pl.when(phase == RIGHT)
  def _():
    accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
local_copy = pltpu.make_async_copy(
    src_ref=accum_scratch,
    dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
    sem=local_copy_sem,
)
local_copy.start()
local_copy.wait()

我们的新内核用以下emit_pipeline调用替换了它:

def inner_kernel(input_ref, accum_ref):
  accum_ref[...] = input_ref[...]
accum_pipeline = pltpu.emit_pipeline(inner_kernel,
                                     in_specs=[inner_block_spec],
                                     out_specs=inner_block_spec,
                                     should_accumulate_out=True,
                                     grid=inner_grid)
@pl.when(~last_iteration)
def _():
  @pl.when(phase == LEFT)
  def _():
    accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
                   hbm_scratch.at[working_slot, left_copy_slice],
    )
  @pl.when(phase == RIGHT)
  def _():
    accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
                   hbm_scratch.at[working_slot, right_copy_slice],
    )

完整的内核如下:

partition = P(None, 'x')
devices = mesh_utils.create_device_mesh((1, num_devices))
mesh = jax.sharding.Mesh(devices, partition)
sharding = jax.sharding.NamedSharding(mesh, partition)

# 我们选择了一个较大的外部内核块大小,我们不希望放置
# 在VMEM中,为了教学目的,我们使用(4096, 4096),尽管在
# 原则上,这个可以大得多。
outer_block_size = (4096, 4096)
# 我们为内部内核选择了一个较小的VMEM块大小。
inner_block_size = (128, 128)
input_arr = jax.random.uniform(
    jax.random.key(0),
    shape=(
        outer_block_size[0] * num_devices,
        outer_block_size[1] * num_devices,
    ),
)
input_arr = jax.device_put(input_arr, sharding)


inner_grid = (
    outer_block_size[0] // inner_block_size[0] // 2,
    outer_block_size[1] // inner_block_size[1],
)
inner_block_spec = pl.BlockSpec(
    index_map=lambda i, j: (i, j),
    block_shape=inner_block_size,
    memory_space=pltpu.TPUMemorySpace.ANY,
)


def reduce_scatter_kernel(
    x_ref,
    o_ref,
    hbm_scratch,
    left_recv_sem,
    left_send_sem,
    copy_sem,
    right_recv_sem,
    right_send_sem,
    left_capacity_sem,
    right_capacity_sem,
):
  outer_step = pl.program_id(0)
  phase = pl.program_id(1)
  is_start = jnp.logical_and(outer_step == 0, phase == 0)
  last_iteration = outer_step == pl.num_programs(0) - 1

  working_slot = lax.rem(outer_step, 2)
  receiving_slot = 1 - working_slot
  my_id = lax.axis_index('x')
  right_neighbor = mod(my_id + 1, num_devices)
  left_neighbor = mod(my_id - 1, num_devices)

  left_copy_device = mod(my_id + outer_step + 1, num_devices)
  right_copy_device = mod(my_id - outer_step - 1, num_devices)
  left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
  right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
  current_phase_slice = pl.ds(
      phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
  )

  initial_left_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, left_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  initial_right_copy = pltpu.make_async_remote_copy(
      src_ref=x_ref.at[my_id, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  left_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[working_slot, left_copy_slice],
      dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
      send_sem=left_send_sem,
      recv_sem=left_recv_sem,
      device_id=(0, left_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )
  right_copy = pltpu.make_async_remote_copy(
      src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
      dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
      send_sem=right_send_sem,
      recv_sem=right_recv_sem,
      device_id=(0, right_neighbor),
      device_id_type=pltpu.DeviceIdType.MESH,
  )

  # --- 序章 ---
  @pl.when(is_start)
  def _():
    # 一开始就与两个邻居为敌,因为我们将
    # 与两者进行沟通。
    barrier_sem = pltpu.get_barrier_semaphore()
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, left_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_signal(
        barrier_sem,
        inc=1,
        device_id=(0, right_neighbor),
        device_id_type=pltpu.DeviceIdType.MESH,
    )
    pltpu.semaphore_wait(barrier_sem, 2)

    initial_left_copy.start()
    initial_left_copy.wait()
    initial_right_copy.start()

    # 我们告诉左边的邻居,它被允许向右边发送消息。
    # (反之亦然,对于右邻)
    signal(LEFT, right_capacity_sem)
    signal(RIGHT, left_capacity_sem)

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      # 我们在此阻塞,直到右邻告知我们可以发送为止。
      # 右边。
      pltpu.semaphore_wait(right_capacity_sem, 1)
      right_copy.start()

    @pl.when(phase == RIGHT)
    def _():
      # 我们在此阻塞,直到左邻告知我们可以发送为止。
      # 左边。
      pltpu.semaphore_wait(left_capacity_sem, 1)
      left_copy.start()

  # --- 正文 ---
  def inner_kernel(input_ref, accum_ref):
    # 我们没有显式地使用 +=,因为我们设置了 should_accumulate_out=True。
    accum_ref[...] = input_ref[...]

  accum_pipeline = pltpu.emit_pipeline(
      inner_kernel,
      in_specs=[inner_block_spec],
      out_specs=inner_block_spec,
      should_accumulate_out=True,
      grid=inner_grid,
  )

  @pl.when(~last_iteration)
  def _():
    @pl.when(phase == LEFT)
    def _():
      accum_pipeline(
          x_ref.at[left_copy_device, left_copy_slice],
          hbm_scratch.at[working_slot, left_copy_slice],
      )

    @pl.when(phase == RIGHT)
    def _():
      accum_pipeline(
          x_ref.at[right_copy_device, right_copy_slice],
          hbm_scratch.at[working_slot, right_copy_slice],
      )

  # --- 尾声 ---
  @pl.when(is_start)
  def _():
    initial_right_copy.wait()

  @pl.when(~is_start)
  def _():
    @pl.when(phase == LEFT)
    def _():
      right_copy.wait()
      signal(LEFT, right_capacity_sem)

    @pl.when(phase == RIGHT)
    def _():
      left_copy.wait()
      signal(RIGHT, left_capacity_sem)

  # 存储最后一次迭代的结果。
  @pl.when(last_iteration)
  def _():
    output_copy = pltpu.make_async_copy(
        src_ref=hbm_scratch.at[working_slot, current_phase_slice],
        dst_ref=o_ref.at[current_phase_slice],
        sem=copy_sem,
    )
    output_copy.start()
    output_copy.wait()

    # 清理信号量,使其以0值退出。
    @pl.when(phase == LEFT)
    def _():
      pltpu.semaphore_wait(right_capacity_sem, 1)

    @pl.when(phase == RIGHT)
    def _():
      pltpu.semaphore_wait(left_capacity_sem, 1)


out_shape = (
    jax.ShapeDtypeStruct(
        (outer_block_size[0], outer_block_size[1]), jnp.float32
    ),
    # 形状:[工作/接收,块[0],块[1]]
    jax.ShapeDtypeStruct(
        (2, outer_block_size[0], outer_block_size[1]), jnp.float32
    ),  # hbm_scratch
)

grid_spec = pltpu.PrefetchScalarGridSpec(
    num_scalar_prefetch=0,
    in_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    out_specs=[
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ],
    grid=(num_devices, 2),
    scratch_shapes=(
        [pltpu.SemaphoreType.DMA] * 5
        + [pltpu.SemaphoreType.REGULAR] * 2  # 容量信号量
    ),
)


def pallas_reduce_scatter(input_arr):
  input_arr = input_arr.reshape(
      num_devices, outer_block_size[0], outer_block_size[1]
  )
  return pl.pallas_call(
      reduce_scatter_kernel,
      out_shape=out_shape,
      grid_spec=grid_spec,
      compiler_params=pltpu.TPUCompilerParams(collective_id=0),
  )(input_arr)[0]


pallas_result = jax.jit(
    shard_map.shard_map(
        pallas_reduce_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
        check_rep=False,
    )
)(input_arr)

pallas_result = jax.block_until_ready(pallas_result)
# 现在我们将结果与XLA进行比较。
def lax_reduce_sum_scatter(x):
  x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
  return lax.psum_scatter(x, 'x')


xla_result = jax.jit(
    shard_map.shard_map(
        lax_reduce_sum_scatter,
        mesh=mesh,
        in_specs=P(None, 'x'),
        out_specs=P('x', None),
    )
)(input_arr)

print('Input:', input_arr.shape, input_arr[::4, 0])
print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
print(
    'Difference |Pallas - lax.psum_scatter|:',
    jnp.max(jnp.abs(pallas_result - xla_result)),
)
Input: (16384, 16384) [0.74162567 0.0242182  0.27751946 ... 0.05213022 0.36088037 0.04494429]
Pallas Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587  1.9148926 ... 1.3371865 1.3296283 1.2887063]
Difference |Pallas - lax.psum_scatter|: 2.3841858e-07

最终说明#

Megacore#

某些TPU包含在Megacore配置中的多个核心。在这种配置下,我们的总体建议是仅从单个核心启动DMA,仅执行HBM-HBM传输。要做到这一点,可以将网格轴之一设置为核心数量(可以通过jax.devices()[0].num_cores获取)并将dimension_semantics设置为“parallel”。然后,您可以使用core_index = pl.program_id(axis)获取该轴上的核心索引,并使用@pl.when(core_index==i)执行特定于该核心的代码。

与XLA的交互#

在本教程中,我们涵盖了几个内核示例,这些示例复制了JAX中集体操作的功能,如lax.all_gatherlax.psumlax.psum_scatter。需要注意的一个重要警告是,Pallas内核对XLA编译器而言是有些模糊的,可能导致它无法执行一些通常会执行的优化。例如,XLA可以异步派发集体操作,以便在不编写自定义内核的情况下交错通信和计算。当涉及Pallas内核时,这并不保证会发生,因此重要的是对您的程序进行性能分析,以查看是否存在此问题。另一个例子是,我们在本教程中使用的emit_pipeline函数生成嵌套管道,但它对XLA编译器是不可见的,因此无法与相邻操作融合。

下一步#

读者可以进行的优秀后续练习包括实现分布式矩阵乘法、实现lax.all_to_all以及放松同步以允许额外的提前运行。