流水线#

在本指南中,我们将讨论TPU中的内存空间如何工作,以及如何在Pallas中编写管道,以便将内存输入/输出与计算重叠。

#@title 导入

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np

TPU及其内存空间#

TPU及其TensorCore由内存空间(数组驻留的地方)、寄存器(临时存储标量和数组值)和计算单元(使用寄存器中的值进行计算)组成。下面是一个TPU的示意图,其中xy是驻留在高带宽内存(HBM)中的数组:

TPU内存空间示意图.png

让我们详细讨论这个图中的组件:

  • 内存空间:TPU具有高带宽内存(HBM),这通常被认为是“设备内存”。还有向量内存(VMEM),用于存储向量和数组值的缓存,还有标量内存(SMEM),用于存储标量值的缓存。

  • **寄None

让我们实现一个Pallas函数,正是这样!

def add_matrices_kernel(x_vmem_ref, y_vmem_ref, z_vmem_ref):
  # 从VMEM加载x和y到VREGs
  x_vregs = x_vmem_ref[:, :]
  y_vregs = y_vmem_ref[:, :]
  # 执行向量加法
  z_vregs = x_vregs + y_vregs
  # 将输出值从VREG寄存器存储回VMEM中
  z_vmem_ref[:, :] = z_vregs


def add_matrices(x: jax.Array, y: jax.Array) -> jax.Array:
  # pallas_call 将首先在 VMEM 中为 `x` 和 `y` 分配临时缓冲区。
  # 随后,它会将`x`和`y`从HBM复制到VMEM中。
  z = pl.pallas_call(
      add_matrices_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
  # pallas_call 还会将 VMEM 的输出复制回 HBM。
  return z


x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们编写了两个函数:add_matrices_kerneladd_matrices

add_matrices_kernel 使用存活在 VMEM 中的 Ref 进行操作。 从 VMEM 的 Ref 加载会产生一个存活在 VREG 的值。 VREG 中的值表现得像 jax.Array,因此我们可以对它们使用 jnpjax.lax 操作来生成新的存活在 VREG 中的值。 当我们生成想要返回的值时,我们将它们存储在输出的 VMEM Ref 中。

add_matrices 函数对 jax.Array 进行操作并返回一个 jax.Array。 在内部,我们将 xy 传递给 pallas_callpallas_call 负责将 xy 复制到 VMEM 中,并为内核操作分配 VMEM 缓冲区(包括分配 z_vmem_ref,即输出的 VMEM 缓冲区)。 内核函数执行完毕后,pallas_call 还会将 z_vmem_ref 中的值复制到 HBM 中,从而产生一个输出的 jax.Array

使用VMEM/SMEM的限制#

Pallas提供对低级内存空间的访问,如VMEM和SMEM,但使用它们编写内核时需要考虑一些因素。

  1. 内存容量。VMEM和SMEM是小的! v4 TPU上的VMEM仅为16MiB,而SMEM的范围在几十到几百KiB之间。如果我们的数组太大,根本无法将它们放入VMEM中。作为参考,一个f32[2048, 2048]的数组是16MiB,因此我们上面的内核在中等大小数组之外无法扩展。

  2. 内存带宽。从HBM和VMEM之间复制数据耗时很长,至少与大多数计算指令相比如此。上面的add_matrices函数可能会花更多时间在HBM和VMEM之间复制数据,而实际上进行加法运算的时间反而较少。

考虑到这两个限制,我们需要重新考虑从TPU中获取性能的策略。

入门:流水线处理#

流水线处理我们的计算提供了一种同时解决内存容量和带宽限制的方法。我们所说的流水线处理是什么?

目标是:并行地从 HBM 和 VMEM 复制数据 同时 利用我们的计算单元。简单地说,这很困难,因为在我们上面的程序中,我们在开始进行计算之前复制了 xy所有 数据,这在复制和计算之间产生了依赖关系。

然而,如果我们能够将计算分成几个子计算(例如,我们在添加两个矩阵时,可以将其表示为将原始矩阵的“块”相加),我们现在可以将其中一个子计算的复制与另一个的计算重叠。让我们通过一个简单的例子来说明:

假设我们将数组 xy 分成 x1, x2y1, y2(例如,沿着首轴分割,导致每个输入生成两个 (256, 512) 的数组)。我们现在可以执行以下流水线计算。

  1. x1y1 复制到 VMEM 中。

  2. 开始将 x2y2 复制到 VMEM 中。

  3. 从 VMEM 中加载 x1, y1 到 VREGs。

  4. 使用计算单元执行 z1 = x1 + y1

  5. z1 存储到 VMEM 中。

  6. 开始将 z1 从 VMEM 复制回 HBM。

  7. 等待 x2, y2 被复制到 VMEM 中。

  8. 从 VMEM 中加载 x2, y2 到 VREGs。

  9. 使用计算单元执行 z2 = x2 + y2

  10. z2 存储到 VMEM 中。

  11. 等待 z1 被复制到 HBM 中。

  12. 开始将 z2 从 VMEM 复制回 HBM。

  13. 等待 z2 被复制到 HBM 中。

在这里,每当我们进行计算时,我们都在异步复制一些数据。这意味着花在复制上的时间不会被浪费。

确定流水线计算效率的两个最重要的数字是:a) 我们需要执行多少次浮点运算(FLOPs),b) 我们需要复制多少字节以执行该计算。这两个数字的比率(FLOPs/内存使用量)被称为操作的 算术强度,它决定了我们的流水线是受计算限制还是受内存限制。

Pallas中的流水线处理#

如何在 Pallas 中实现类似上述的管道? 这似乎是一个复杂的异步数据操作序列和内核执行,手动实现将会非常困难。 别担心!Pallas 提供了一个 API,使我们能够以不需要过多样板代码的方式表达管道,即通过 gridBlockSpec

在上述管道示例中,我们多次执行相同的逻辑:步骤 3-5 和 8-10 都执行相同的操作,只是输入不同。 jax.experimental.pallas.pallas_call() 提供了一种通过使用 grid 参数多次执行内核的方法。 参见 grid,即循环中的内核

我们还使用 jax.experimental.pallas.BlockSpec 来指定如何构造每次内核调用的输入。 参见 BlockSpec,又名如何分割输入

在上述管道示例中,我们有形状为 (512, 512) 的数组,并将它们沿着首个维度拆分成两个形状为 (256, 512) 的数组。 在这个管道中,我们的 BlockSpec.block_shape 将是 (256, 512)。 在第一次迭代中,我们希望选择 x1,而在第二次迭代中,我们希望使用 x2。 这可以用以下的 index_map 来表达:

def x_index_map(i):
  return (i, 0)

然后我们构造 BlockSpec

block_spec = pl.BlockSpec((256, 512), x_index_map)

yzBlockSpec 将与 xBlockSpec 相同。

综合起来#

我们通过 gridin_specsout_specs 将这些参数提供给 pallas_callin_specs 对应位置参数的元组,out_specs 对应输出)。

def add_matrices_pipelined(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,)
  )(x, y)

add_matrices_pipelined(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

我们只在原始函数中增加了少量代码以添加自动流水线,但 BlockSpecgrid 进行了大量的繁重工作!

它是如何工作的呢?实际上,BlockSpec 提供了足够的信息来开始从 HBM 预取输入块到 VMEM。例如,如果我们正在开始 grid 的迭代 i,我们可以将 i + 1 传递给 index_map 函数以获取下一次迭代所需的块。然后,我们可以为这些块开始异步复制。对于输出,我们可以在开始当前迭代的输出复制之前等待前一次迭代的输出被复制。

参数化管道#

在我们的内核中,对块形状进行参数化是很常见的。块大小可能是在优化 Pallas 内核性能时最重要的参数!它们让我们能够控制流水线(例如,选择更小的块会增加我们流水线循环的迭代次数,而每次迭代的工作量则较少)。

此外,我们还可以在第二维度上对输入和输出进行划分(目前我们只是在第一维进行划分)。让我们编写一个更通用的内核,以处理这两个特性。

def add_matrices_pipelined_2d(
    x: jax.Array, y: jax.Array, *, bm: int = 256, bn: int = 256
) -> jax.Array:
  m, n = x.shape
  block_spec = pl.BlockSpec((bm, bn), lambda i, j: (i, j))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(m // bm, n // bn),
  )(x, y)

np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=256, bn=256), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=128, bn=128), x + y
)
np.testing.assert_array_equal(
    add_matrices_pipelined_2d(x, y, bm=512, bn=512), x + y
)

处理缩减#

如何使用 pallas_call 实现类似 jnp.sum 的功能? 具体来说,我们希望在减小维度上进行流水线处理。

以将形状为 (8, 512, 512) 的数组减少到形状为 (512, 512) 为例。

x = jnp.ones((8, 512, 512))
jnp.sum(x, axis=0)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

要使用 pallas_call 实现这一点,我们可以使用大小为 (8,) 的网格,并在每次迭代 i 中将 x[i] 加载到 VMEM 中。 然后我们可以将 x[i] 添加到一个输出 VMEM 缓冲区。我们先简单实现这一点。

# 警告:此实现有误!

def naive_sum_kernel(x_ref, o_ref):
  o_ref[...] += x_ref[...]

def naive_sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      naive_sum_kernel,
      grid=grid,
      # `block_shape` 中的 None 表示我们选择大小为 1 的尺寸并将其压缩掉。
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
  )(x)
naive_sum(x)
Array([[9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       ...,
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.],
       [9., 9., 9., ..., 9., 9., 9.]], dtype=float32)

注意我们是如何设置 BlockSpec 的:我们将 (512, 512) 维度的完整内容加载到 VMEM(这里没有管道化),但在 index_map 中每次迭代选择 x 的第 i 维。我们在块形状中对该维度使用 None,这表示我们选择的是 x 中的一个单一维度,而我们希望在内核中将其压缩掉。因此,x_ref 在 VMEM 中也具有 (512, 512) 的形状。

out_spec 使用 lambda i: (0, 0) 作为其 index_map,这表示在整个管道过程中 o_ref 是不变的。这意味着我们可以通过读写它来更新它的值。或者说可以这样做吗?实际上,有一个问题:o_ref 最初是垃圾,这意味着我们将累积在垃圾中。这将导致整体函数输出不正确的值!

因此,每当我们在内核中进行归约时,我们需要确保初始化存储归约值的 Ref。我们可以通过在迭代为 0 时有条件地向 out_ref 写入一个值来完成这一点。我们可以使用帮助函数 pl.when,这是一个围绕 jax.lax.cond 的便利包装,以及 pl.program_id,它查询我们在网格轴中的迭代次数。

def sum_kernel(x_ref, o_ref):
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)

  o_ref[...] += x_ref[...]

def sum(x: jax.Array) -> jax.Array:
  grid, *out_shape = x.shape
  return pl.pallas_call(
      sum_kernel,
      grid=grid,
      # `block_shape` 中的 None 表示我们选择大小为 1 的尺寸并将其压缩掉。
      in_specs=[pl.BlockSpec((None, *out_shape), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec(out_shape, lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype)
  )(x)

sum(x)
Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

这个 sum 函数现在输出了正确的值!

关于 Pallas 中的归约还有最后一点需要注意:它们必须在我们网格的最小维度(最右侧维度)进行(在上述示例中,我们的网格是一维的,所以我们在其最小维度上进行归约)。这是因为 Pallas 通过 BlockSpecgrid 和内核函数生成的管道不从 HBM 读取输出。一旦您将输出值写回 HBM,就无法再次访问它。因此,您无法在有任何重新访问的网格维度上进行归约,因此所有归约都需要在最右侧维度上进行。

Megacore配置中的TPU#

一些TPU芯片有两个TensorCore,但对JAX用户而言,表现为一个设备。这称为“巨核”(megacore)。单独的TensorCore拥有各自独立的VMEM、VREG、SMEM、SREG以及计算单元,但共享HBM

TPU内存空间卡通图(巨核)

从概念上讲,巨核中的TPU表现得像非常简单的GPU,即它们只有两个线程。我们如何修改我们的内核以同时利用两个TensorCore?

基本的思路是,如果我们在计算中有令人尴尬的并行维度,我们可以将这些维度分配到TensorCore上。我们可以通过给pallas_call提供一个名为dimension_semantics的注解来表明哪些维度是可以并行的。

def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
  block_spec = pl.BlockSpec((256, 512), lambda i: (i, 0))
  return pl.pallas_call(
      add_matrices_kernel,
      out_shape=x,
      in_specs=[block_spec, block_spec],
      out_specs=block_spec,
      grid=(2,),
      compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
  )(x, y)

x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
add_matrices_pipelined_megacore(x, y)
Array([[2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       ...,
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.],
       [2., 2., 2., ..., 2., 2., 2.]], dtype=float32)

dimension_semantics 应该是一个与 grid 长度相同的元组,每个条目要么是 "parallel",要么是 "arbitrary""parallel" 表示对 Pallas 来说,与该维度对应的 for 循环的迭代可以独立执行,而不会影响程序的正确性。"arbitrary" 表示Pallas无法对该网格维度做任何假设,因此无法进行并行化。

通过指定 dimension_semantics,我们现在可以在每个 TensorCore 上同时执行内核。Pallas 将自动处理网格的拆分。

请注意,Megacore 目前仅在 TPU v4 和 TPU v5p 上可用。在其他平台上提供 dimension_semantics 注释是没有效果的,但 指定它将导致仅使用一个 TensorCore(即使有多个可用)。

结论#

在本指南中,我们介绍了如何使用 pallas_callgridBlockSpec 来表达 TPU 管道。我们讨论了如何通过多维网格来表达嵌套循环,以及如何在减少操作开始时初始化我们的累加器来处理归约。我们还学习了如何通过为内核添加注释来处理 Megacore。

留给读者的练习:

  • 尝试实现一个 sum 内核,该内核也对其他维度进行管道处理

  • add 内核和 sum 内核添加 Megacore 支持。