Pallas 快速入门#

Pallas 是 JAX 的一个扩展,能够为 GPU 和 TPU 编写自定义内核。 Pallas 允许您使用相同的 JAX 函数和 API,但在 较低 的抽象级别上进行操作。

具体而言,Pallas 要求用户考虑内存访问以及如何在硬件加速器的多个计算单元之间划分计算。 在 GPU 上,Pallas 降低到 Triton,而在 TPU 上,Pallas 降低到 Mosaic。

让我们深入一些示例。

注意:Pallas 仍然是一个实验性 API,您可能会因更改而遭遇问题!

Pallas中的你好世界#

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np

我们将首先在Pallas中编写“你好,世界”,这是一个用于添加两个向量的内核。

def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

Ref 类型

让我们稍微剖析一下这个函数。与您可能编写的大多数 JAX 函数不同,它并不接受 jax.Array 作为输入,也不返回任何值。相反,它接受 Ref 对象作为输入。请注意,我们也没有任何输出,但我们给定了一个 o_ref,它对应于所需的输出。

Ref 中读取

在函数体内,我们首先从 x_refy_ref 中读取,使用 [...] 表示(省略号意味着我们在读取整个 Ref;或者我们也可以使用 x_ref[:])。以这种方式从 Ref 中读取会返回一个 jax.Array

写入 Ref

然后我们将 x + y 写入 o_ref。在历史上,JAX 并不支持变异 – jax.Array 是不可变的!Ref 是新的(实验性)类型,允许在某些情况下进行变异。我们可以将写入 Ref 理解为改变其底层缓冲区。

所以我们编写了一个我们称之为“内核”的程序,我们定义它为一个将作为加速器上一个原子执行单元运行的程序,而无需与主机进行任何交互。 我们如何从 JAX 计算中调用它? 我们使用 pallas_call 高阶函数。

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
add_vectors(jnp.arange(8), jnp.arange(8))
Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pallas_call 将 Pallas 内核函数提升为可以作为更大 JAX 程序一部分调用的操作。但为了做到这一点,它需要更多的细节。在这里,我们指定 out_shape,这是一个具有 .shape.dtype(或其列表)的对象。out_shape 决定了我们 add_vector_kernelo_ref 的形状/dtype。

pallas_call 返回一个接受和返回 jax.Array 的函数。

这里实际上发生了什么?

到目前为止,我们已经描述了如何考虑Pallas内核,但我们实际完成的是我们正在编写一个非常接近计算单元执行的函数。

在GPU上,x_ref对应于高带宽内存(HBM)中的一个值,当我们执行x_ref[...]时,我们将该值从HBM复制到静态RAM(SRAM)(一般来说,这是一个昂贵的操作!)。 然后我们使用GPU向量计算来执行加法,将结果值从SRAM复制回HBM。

在TPU上,我们做了稍微不同的事情。在内核执行之前,我们将值从HBM获取到SRAM。因此,x_ref对应于SRAM中的一个值,当我们执行x_ref[...]时,我们将该值从SRAM复制到寄存器。 然后我们使用TPU向量计算来执行加法,将结果值再次复制回SRAM。在内核执行后,SRAM中的值被复制回HBM。

我们正在编写特定于后端的Pallas指南,敬请期待!

Pallas 编程模型#

在我们的“你好,世界”示例中,我们编写了一个非常简单的内核。 它利用了我们的8大小数组可以很好地适应硬件加速器的SRAM。 在大多数现实应用中,情况并非如此!

编写Pallas内核的一部分是思考如何处理存储在高带宽内存(HBM,也称为DRAM)中的大数组,并表达在能在SRAM中适应该数组“块”的计算。

通过示例理解网格#

为了自动“划分”输入和输出,您需要向pallas_call提供一个gridBlockSpec

grid是一个整数的元组(例如,()(2, 3, 4)(8,)),指定了一个迭代空间。 例如,一个网格(4, 5)将有20个元素: (0, 0),(0, 1),...,(0, 4),(1, 0),...,(3, 4)。 我们对每个元素运行内核函数一次,这是一种单程序多数据(SPMD)编程风格。

二维网格的可视化

二维网格

当我们向pallas_call提供一个grid时,内核将执行prod(grid)次。每次调用被称为一个“程序”。 要访问当前内核正在执行的程序(即网格的哪个元素),我们使用program_id(axis=...)。 例如,对于调用(1, 2)program_id(axis=0)返回1program_id(axis=1)返回2

这是一个使用gridprogram_id的示例内核。

def iota_kernel(o_ref):
  i = pl.program_id(0)
  o_ref[i] = i

我们现在使用 pallas_call 执行它,并添加一个 grid 参数。

def iota(size: int):
  return pl.pallas_call(iota_kernel,
                        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
                        grid=(size,))()
iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

在GPU上,每个程序都是在独立的线程上并行执行的。因此,我们需要考虑写入高带宽内存(HBM)时的竞争条件。一种合理的方法是以不同的方式编写我们的内核,使得不同的程序写入HBM中的不重叠位置,以避免这些并行写入。另一方面,计算的并行化是我们能够快速执行矩阵乘法等操作的关键。

在TPU上,程序的执行是并行与顺序的结合(取决于架构),因此需要考虑稍微不同的因素。

您可以在 grid,即循环中的内核 中阅读更多详细信息。

通过示例的块规格#

考虑到gridprogram_id,Pallas提供了一个抽象层,以处理许多内核中常见的索引模式。为了建立直觉,让我们尝试实现一个矩阵乘法。

在Pallas中实现矩阵乘法的一个简单策略是递归实现。我们知道我们底层的硬件支持小规模的矩阵乘法(使用GPU和TPU张量核心),因此我们只需将大规模矩阵乘法表示为小规模的矩阵乘法。

假设我们有输入矩阵\(X\)\(Y\),并计算\(Z = XY\)。我们首先将\(X\)\(Y\)表示为块矩阵。\(X\)将具有“行”块,\(Y\)将具有“列”块。

\[\begin{split} \begin{align*} X = \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \end{align*} \end{split}\]
\[ \begin{align*} Y = \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \end{align*} \]
\[\begin{split} \begin{align*} Z &= \begin{bmatrix} X_0 \\ X_1 \end{bmatrix} \begin{matrix} \begin{bmatrix} Y_0 & Y_1 \end{bmatrix} \\ ~ \end{matrix} \\ &= \begin{bmatrix} X_0 Y_0 & X_0 Y_1 \\ X_1 Y_0 & X_1 Y_1 \end{bmatrix} \end{align*} \end{split}\]

我们的策略是,因为\(Z\)也是一个块矩阵,我们可以将我们的Pallas内核中的每个程序分配给一个输出块。计算每个输出块对应于在\(X\)的“行”块和\(Y\)的“列”块之间进行更小的矩阵乘法。

为了表达这一模式,我们使用 BlockSpecBlockSpec 为每个输入和输出指定一个块形状,以及一个“索引映射”函数,该函数将一组程序索引映射到一个块索引。

BlockSpec 的可视化`

BlockSpec 的可视化

举一个具体的例子,假设我们想将两个 (1024, 1024) 的矩阵 xy 相乘以生成 z,并希望将计算并行化 4 种方式。我们将 z 分成 4 个 (512, 512) 的块,每个块通过进行 (512, 1024) x (1024, 512) 的矩阵乘法来计算。为了表达这一点,我们首先使用 (2, 2) 的网格(每个程序一个块)。

对于 x,我们使用 BlockSpec((512, 1024), lambda i, j: (i, 0)) – 这将 x 切分成“行”块。要看到这一点,可以看一下程序实例 (1, 0)(1, 1) 如何选择 x 中的 (1, 0) 块。对于 y,我们使用转置版本 BlockSpec((1024, 512), lambda i, j: (0, j))。最后,对于 z,我们使用 BlockSpec((512, 512), lambda i, j: (i, j))

这些 BlockSpec 通过 in_specsout_specs 传递给 pallas_call

有关 BlockSpec 的更多详细信息,请参见 BlockSpec,又名如何分割输入

在底层,pallas_call 会自动将您的输入和输出切分为每个块的 Ref,这些 Ref 将传递给内核。

def matmul_kernel(x_ref, y_ref, z_ref):
  z_ref[...] = x_ref[...] @ y_ref[...]

def matmul(x: jax.Array, y: jax.Array):
  return pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j),
    )
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
np.testing.assert_allclose(z, x @ y)

注意,这是一个非常简单的矩阵乘法实现,但请将其视为各种优化的起点。 让我们为我们的矩阵乘法添加一个额外的功能:融合激活。 这实际上非常简单!只需将一个高阶激活函数传递给内核即可。

def matmul_kernel(x_ref, y_ref, z_ref, *, activation):
  z_ref[...] = activation(x_ref[...] @ y_ref[...])

def matmul(x: jax.Array, y: jax.Array, *, activation):
  return pl.pallas_call(
    partial(matmul_kernel, activation=activation),
    out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype),
    grid=(2, 2),
    in_specs=[
        pl.BlockSpec((x.shape[0] // 2, x.shape[1]), lambda i, j: (i, 0)),
        pl.BlockSpec((y.shape[0], y.shape[1] // 2), lambda i, j: (0, j))
    ],
    out_specs=pl.BlockSpec(
        (x.shape[0] // 2, y.shape[1] // 2), lambda i, j: (i, j)
    ),
  )(x, y)
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y, activation=jax.nn.relu)
np.testing.assert_allclose(z, jax.nn.relu(x @ y))

最后,让我们强调Pallas的一个酷功能:它可以与jax.vmap结合使用! 要将这个矩阵乘法转换为批处理版本,我们只需要使用vmap即可。

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (4, 1024, 1024))
y = jax.random.normal(k2, (4, 1024, 1024))
z = jax.vmap(partial(matmul, activation=jax.nn.relu))(x, y)
np.testing.assert_allclose(z, jax.nn.relu(jax.vmap(jnp.matmul)(x, y)))