Pallas 设计#

在本文件中,我们解释了最初的 Pallas 设计。这是一些早期设计决策的快照,自那时起 Pallas 的具体 API 可能已经发生了变化。

介绍#

JAX 被用于各种不同的工作负载,从大规模机器学习到科学计算。JAX 的成功故事同样也是 XLA 的成功故事,XLA 是 JAX 主要针对的编译器——XLA 为加速器编译 JAX 程序,并使 JAX 能够扩展到最大的机器学习模型。JAX 在 XLA 的表示形式 HLO 中描述逻辑计算。HLO 描述了计算如何逻辑上发生,但不是物理上。给定一个逻辑 HLO 计算,XLA 决定该计算如何物理上执行。对于广泛的机器学习应用,XLA 很好地编译了用户程序,但不可避免地有些用户会遇到 XLA 的限制。在这些情况下,我们需要提供一个“逃生舱”,允许专家编写手动调优的内核,这些内核在那个时间点上优于 XLA。此外,机器学习系统研究的进展需要一些时间才能被纳入 XLA,而用户通常希望提前运行它们。随着时间的推移,编译器可以纳入通过手动调优内核实验证明的优化。

XLA 确实提供了 CustomCall 机制作为逃生舱,但它要求用户编写 C++,并且在 GPU 上要求用户学习 CUDA 编程模型。CUDA 编程模型对于许多机器学习 GPU 内核(如矩阵乘法)来说,可以说是过于底层,即使是专家用户也会在使用 CUDA 实现高效的矩阵乘法或多头注意力时遇到困难。不仅如此,JAX 用户通常熟悉 Python 和 NumPy 风格的数组编程,这不需要编写任何 C++ 或考虑 GPU 并行性。所有流行的机器学习框架都共享这一理念:使用高层次操作(如 matmulconvolution)操作(通常是)数组。不幸的是,这意味着通过 CustomCall 实现自定义操作是一项巨大的投资,可能涉及学习 C++ 和/或 GPU 编程。

Triton,由OpenAI构建和维护的GPU编译器,已经在ML编译器世界中引起了轰动。Triton提供了两全其美的方案:一个基于数组的GPU内核编程模型。Triton是PyTorch 2.0中torch.compile的主要代码生成路径,通过Torch Inductor库实现。Triton在更易访问的编程模型的名义下,积极隐藏了GPU编程的某些方面,该模型可以从Python中使用,并从更高层次的表示生成优化代码。虽然GPU比Triton提供的更灵活,但在ML领域,Triton似乎对许多应用来说已经足够表达。

在本文件中,我们描述了 Pallas,这是 JAX 的一个扩展,它使用类似 Triton 的模型为 GPU 和 TPU 启用内核编程。基于 JAX 的内核语言提供了几个优势:

  • 尽管 Triton 向用户暴露了一个类似 TPU 的编程模型,即在 L1 缓存中为数组的块编写程序,但它足够专门化于 GPU,以至于我们无法直接为 TPU 编译 Triton。例如,Triton 提供了专门用于处理不一定在 TPU 上有意义的并行写入的原子操作。一个更高层次的前端可以在抽象平台细节的同时,只暴露基于块的编程模型。因此,内核将能够在不同的硬件平台上移植。

  • JAX 作为一个基于追踪的数值计算前端,既成熟又广泛使用。通过将内核编程语言嵌入到 JAX 本身,我们可以重用 JAX 的追踪基础设施,并提供一个类似于 NumPy 的前端,这对用户来说已经很熟悉了。

  • JAX 变换是其成功的关键,允许用户表达简单的程序,但通过变换实现复杂的功能。我们可以利用相同的变换(vmap、jvp 等)来变换用户编写的内核。

开放的问题是:JAX 是否适合作为内核语言?我们认为如此。Triton 证明了数组编程语言可以实际用于编写 GPU 内核,而 JAX 正是这样一种语言。JAX 还证明了它是一个灵活的编译器前端和程序转换工具。

我们将Pallas描述如下:我们首先描述了扩展JAX以支持编写自定义内核的方式。然后,我们展示了如何将Pallas降低到Triton和Mosaic。最后,我们通过JAX变换描述了现有和潜在的Pallas内核变换方法。

Pallas 降低路径 Pallas 降低路径的可视化

Pallas: 扩展 JAX 以支持内核#

我们想要强调的关键点是,Pallas 只是 JAX,带有一些扩展:

  1. 用户现在在他们的 JAX 代码中使用称为 Ref 的引用类型。这使用户能够更精确地控制内存访问和布局,使 JAX 更接近物理布局。

  2. 用户使用 JAX 原语的一个子集以及一组特定于 Pallas 的原语来编写他们的 JAX 程序。

  3. 用户通过一个特殊的 pallas_call 高阶函数将他们的 Pallas 内核嵌入到外部 JAX 程序中,该函数在一个映射中执行内核。它类似于 pmapshard_map,除了具有对共享内存的引用。

我们将通过示例逐一介绍这三种扩展。

请注意,这些API仍在实验阶段,可能会发生变化。

参考类型#

让我们来看一个用于向量相加的 Pallas 程序示例:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)

与普通的 JAX 程序不同,add_kernel 不接收不可变的数组参数。相反,它提供的是可以读取和使用类似 NumPy 的语法就地更新的引用。Ref 不是 Pallas 特有的概念——它们是为了表示有状态计算而引入 JAX 的。然而,在编写操作可变内存的内核时,我们也可以利用它们。

Pallas 内核不仅接收与内核输入相对应的 Ref,还接收与输出相对应的 Ref(通过 pallas_call 中的 out_shape 指定)。Ref 是特殊类型,不能在不先读取的情况下传递到 JAX 原语的通常集合中。当你从 Ref 读取时,你会得到一个 JAX Array 类型,并且你必须将一个 Array 写入 Ref

读取/写入引用#

Ref 读取相当于将数组加载到内存层次结构的最低级别(GPU 上的 L1 缓存和 TPU 上的向量寄存器)。写入 Ref 与此类似。

def f(x_ref, o_ref):
  # Using vanilla Python indexing
  x = x_ref[0, 2:5, :]
  # Or via Numpy advanced int indexing
  o_ref[jnp.arange(3), :] = x

# Note that in order to use NumPy advanced int indexing, you need to broadcast the indices against each other into the desired multidimensional shape:
def f(x_ref):
  # Assume x_ref is (8, 4) and we want to read out a (2, 3) slice
  x = x_ref[jnp.arange(2)[..., None], jnp.arange(3)[None, ...]]

可以通过类似 __setitem__ 风格的索引方式写入 Ref

其他形式的索引(例如,动态切片)可以通过 pallas.loadpallas.store 完成,这是为简化从内存加载/存储而设计的新 JAX 原语。我们将在后面讨论这些新原语。

使用新的 Pallas 原语扩展 JAX#

因为 JAX 的设计考虑了 HLO,JAX 原语集与 HLO 操作集紧密对应。针对一个新的编译器(例如 Triton 或 Mosaic),我们可能需要为新编译器补充特定的 JAX 原语。同时,我们可能无法降低所有 JAX 原语,因此需要将其限制在一个子集内。

由于 Pallas 最初是针对 Triton 设计的,我们提供了一组针对 Triton 编程模型的新原语。正如我们稍后将展示的,我们也可以将这些原语降低为 Mosaic。

pallas.loadpallas.store#

pallas.loadpallas.store 是允许从内存加载和存储到内存的原语。与 __getitem____setitem__ 不同,它们以更冗长为代价提供了更大的灵活性。具体来说,你可以使用 pallas.dynamic_slice(简称 pallas.ds)构造(这可能应该被上游到 JAX 中,以便与 Ref 的 __getitem____setitem__ 一起使用)。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  x = pl.load(x_ref, (0, slice(2, 5), slice(None)))
  # Using integer indexing automatically broadcasts
  x = pl.load(x_ref, (0, 2 + jnp.arange(3), slice(None)))
  # You can also use `pl.dynamic_slice` (`pl.ds` for short) objects as well
  pl.store(o_ref, (0, pl.ds(start=2, size=3), slice(None)), x)

pallas.loadpallas.store 也通过 mask 参数支持掩码操作。

def f(x_ref, o_ref):
  # Reading from memory via pallas.load
  idx = jnp.arange(8)
  mask = idx < 5
  x = pl.load(x_ref, (idx,), mask=mask, other=float('-inf'))

在进行越界加载/存储时,掩码操作非常重要。掩码的操作语义可以由编译器决定(如果我们正确理解文档,Triton 会在掩码时避免从/向内存读取/写入)。

pallas.program_idpallas.num_programs#

正如我们很快会看到的,我们将多次执行相同的 Pallas 内核(根据后端的不同,可能是并行或流水线方式)。这些新的原语告诉我们“在执行内核的哪个位置”。

pallas.program_id 接受一个轴参数,该参数告诉我们当前内核正在多维网格的哪个轴索引中执行(类似于CUDA编程中的threadIdjax.pmap中的lax.axis_index)。请注意,我们目前从Triton借用了“程序”术语,未来我们可能会将其更改为JAX用户更熟悉的术语。

def f(x_ref, o_ref):
  i = pl.program_id(axis=0)  # execution index in the first axis of the grid
  o_ref[i] = jnp.exp(x_ref[i])

pallas.num_programs 也接受一个轴,并返回该轴的网格大小。

需要注意的是,虽然 program_idnum_programs 是 Triton 特有的术语,但它们很容易被泛化,以便在 TPU 上也能有意义。

在 Pallas 中使用 JAX 原语的子集#

因为我们编写的是内核,而不是高级HLO程序,所以某些JAX原语可能无法在我们的底层基质中高效表示。然而,我们知道我们可以支持大多数逐元素操作、简单的点积和JAX控制流。

虽然我们还没有完全规划出所有可以在 Pallas 内核中支持的 JAX 原语,但我们可以确定一些不容易降低或不太可能有用的原语:

  • conv_general - 卷积通常在底层硬件中并不作为原语提供。

  • gather/scatter - 底层编译器可能不支持非连续内存的读取和写入

使用 pallas_call 执行 Pallas 内核#

既然我们已经编写了Pallas内核(即带有Refs和额外Pallas原语的JAX),我们如何在GPU或TPU上执行它们呢?我们使用pallas_call,这是一个高阶函数(类似于jax.jitjax.pmap),用于执行内核。

pallas_call 的签名如下:

def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...

当我们为 pallas_call 提供一个内核时,我们提供了额外的信息。首先是 out_shape,它告诉内核输出是什么样子的(pallas_call 将传递一个与这些输出对应的 Ref 到内核中以供写入)。其余的信息(in_specsout_specsgrid)是关于内核如何在加速器上调度执行的信息。

pallas_call 的(大致)语义如下:

def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices): # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
                      zip(args, in_specs)]
      local_outputs = [out_spec.transform(arg, indices) for arg, out_spec  in
                       zip(outputs, out_specs)]
      kernel(*local_inputs, *local_outputs) # writes to outputs
  return execute

具体来说,pallas_call 将对网格迭代空间进行“循环”,对通过 in_specsout_specs 指定的输入和输出应用转换。在每次迭代中,内核将在转换后的输入和输出上调用。请注意,迭代空间上的“循环”可以并行执行(例如在GPU上)。pallas_call 不保证迭代空间上循环迭代的顺序,只是确保迭代空间的每个成员都会被循环遍历。像 Triton 和 Mosaic 这样的编译器将具有与网格相关的更具体的操作语义。

转换函数#

pallas_callin_specsout_specs 参数允许以某种方式转换输入和输出。Pallas 目前提供的两种选项是恒等变换(输入和输出保持不变),以及 BlockSpec,它根据循环索引获取 Ref 的固定大小切片。

一个 BlockSpec 接受一个 index_map 函数和一个 block_shape。从逻辑上讲,它接受一个数组,并沿着每个轴将其切片为 block_shape 大小的块。index_map 函数接受循环索引(来自网格索引集),并将它们映射到块索引。转换函数将 Ref 转换为相应块中 Ref 的逻辑视图。当我们在 block_shape 的一个条目中指定 None 时,这对应于“映射”该维度,从内核中的块中移除它。

class BlockSpec:
  index_map: Callable[[Tuple[Int, ...]], Tuple[Int, ...]]
  block_shape: Tuple[Optional[int], ...]

  def transform(self, ref, *loop_indices):
    block_indices = self.transform_function(loop_indices)
    # Returns a view of `ref` starting at `block_indices` of shape self.block_shape
    ...

我们也可以设想其他与 pallas_call 一起使用的 Spec,例如,对应于重叠窗口的 Spec,用于实现卷积。

Pallas 作为前端的即时优势#

通过为内核编写提供JAX前端,我们可以立即获得一些好处。

更灵活的前端#

首先,JAX 用户已经习惯了使用 JAX 及其基于追踪的变换进行编程的好处(及其限制)。这意味着用户在编写 Pallas 内核时可以使用闭包和其他熟悉的 Python 结构。这与现有的基于 AST 解析的 Triton 前端或 Mosaic 的 MLIR 构建器不同。例如,这使得 Pallas 比 Triton 更易于模板化。

请参阅此示例,了解我们如何在Python中使用高阶函数来模板化内核。

def make_kernel(eltwise_kernel):
  def add(x_ref, y_ref, o_ref):
    x = pl.load(x_ref, ())
    y = pl.load(y_ref, ())
    pl.store(o_ref, (), eltwise_kernel(x + y))
  return add

kernel1 = make_kernel(lambda x: x * 2)
kernel2 = make_kernel(jnp.exp)

pl.pallas_call(kernel1, out_shape=x, grid=1)(1., 1.)
pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)

仿真模式#

通过将内核表示为带有JAX原语和一些新的Pallas原语的程序,我们还可以直接将Pallas程序降低到StableHLO,并使用XLA编译/执行它们。具体来说,一个 pallas_call 可以实现为一个在网格上的 lax.scan。这使我们能够在任何支持XLA的平台(甚至是CPU!)上开发GPU或TPU内核,并使用JAX/XLA调试工具(如 jax.debug.print)进行调试。我们还可以使用更可靠和经过更好测试的XLA数值来验证Triton和Mosaic编译器的正确性。也可以想象通过扰动 scan 顺序来模拟在GPU上发生的并行读写。

示例#

add#

我们修改 add_kernel 示例,使用 BlockSpec 对 (2,) 大小的块进行操作。

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (2,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y
x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
    in_specs=[
        pl.BlockSpec((2,), lambda i: i),
        pl.BlockSpec((2,), lambda i: i)
    ],
    out_specs=pl.BlockSpec((2,), lambda i: i),
    grid=(4,))
add(x, y)

模板化的矩阵乘法#

在这个例子中,我们通过对输入数组的行和列块进行展开的累加来计算输出块。我们使用高阶函数将激活函数内联到内核的主体中,以便我们可以发出一个融合的内核。

def matmul_kernel(x_ref, y_ref, o_ref, *, activation, block_k):
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  for k in range(x_ref.shape[1] // block_k):
    x = x_ref[:, k*block_k:(k+1)*block_k]
    y = y_ref[k*block_k:(k+1)*block_k, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

x, y = jnp.ones((512, 256)), jnp.ones((256, 1024))
block_shape = 128, 256, 128

@partial(jax.jit, static_argnames=["block_shape", "activation"])
def matmul(x, y, *, block_shape, activation):
  block_m, block_n, block_k = block_shape
  fused_matmul = pl.pallas_call(
      partial(matmul_kernel, block_k=block_k, activation=activation),
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
      in_specs=[
          pl.BlockSpec((block_m, x.shape[1]), lambda i, j: (i, 0)),
          pl.BlockSpec((y.shape[0], block_n), lambda i, j: (0, j))
      ],
      out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
      grid=(4, 4),
  )
  return fused_matmul(x, y)

z = matmul(x, y, block_shape=block_shape, activation=jax.nn.gelu)

降低Pallas#

在用户表达他们的 Pallas 内核后,我们会根据目标后端将其转换为不同的表示形式。在 GPU 上,我们将 Pallas 转换为 Triton IR,而在 TPU 上,我们将 Pallas 转换为 Mosaic。

将Pallas降低到Triton以用于GPU#

将Pallas降低到Triton很容易,因为Pallas在设计时就已经考虑到了Triton作为目标语言。Pallas和Triton之间的主要区别在于,Triton没有BlockSpec的概念,并且在进行内存加载和存储时使用指针而不是索引。

Triton 在其语言中支持将指针作为数组元素类型,并且在 Triton 中你可以从指针数组中加载数据或将数据存储到指针数组中。在 Pallas 中,当给定一个形状为 (4, 5)Refx_ref,然后执行类似 x_ref[3, 2] 的操作时,我们需要将其转换为计算指向 x_ref 中适当行主序位置的 Triton 指针(即执行 5 * 3 + 2 * 1)。同样地,当我们将切片转换为 Triton 时,例如 x_ref[4, :],我们需要生成一个指针数组 5 * 4 + jnp.arange(3)

除此之外,降低到 Triton 是相当直接的。JAX 点积可以降低到 Triton 点积,JAX 一元原语降低到它们的 Triton 等价物。Triton 的原子操作通过新的 Pallas 原子原语降低。

将 Pallas 降低到 TPU 的 Mosaic#

Mosaic 消耗(主要是)标准方言 MLIR 并生成 LLO 以编译为 TPU。Pallas 可以通过将 JAX 原语转换为 MLIR(主要是 vectorarith 方言)来降低为 Mosaic。BlockSpec 可以转换为流水线调度(即 Mosaic 中的 transform_func)。

转换 Pallas#

一个自然的问题是 JAX 变换如何与 Pallas 内核交互?主要有两种方式:Pallas 内核内部的变换和 Pallas 内核外部的变换。

在 Pallas 内核中的转换实际上应该“只是工作”,只要我们能够降低转换后的代码。例如,我们可以在 JAX 内核中使用 jax.grad(jnp.sin)(...),因为我们能够将 cos 降低到 Triton 和 Mosaic。然而,我们可能无法降低 jax.vmap(lax.dynamic_slice),因为它可能会变成我们无法降低的 gather。

从外部 JAX 程序转换 Pallas 内核可能是更有趣的情况。我们如何处理诸如 vmap(pallas_call)grad(pallas_call) 之类的事情?

vmap-of-pallas_call#

vmap 自动向量化 JAX 程序。虽然内核编写者可能希望精确控制批处理内核与其非批处理变体的行为差异,但我们可以在提供 jax.custom_vmap 自定义机制的同时,为 pallas_call 提供一个合理的默认 vmap 规则。当 pallas_callvmap 处理时,我们增强 pallas_call 以增加一个对应于新批处理维度的额外网格维度,并转换 BlockSpec 以处理沿该维度的索引。

grad-of-pallas_call#

pallas_callgrad 实现了内核的自动微分。jax.grad 分解为三种不同变换的应用:jvppartial_evaltranspose。原则上,在为 pallas_call 实现这些规则时,我们可以重用 JAX 的大部分基础设施(因为它与现有的 JAX 高阶原语行为非常相似)。

然而,内核的自动微分可能会由于内存访问的转置方式而导致性能下降。如果我们编写一个具有重叠并行读取和不相交并行写入的GPU内核,我们会自动将其转置为一个具有重叠并行写入(这在原子操作时很慢)和不相交并行读取的内核。为了发出一个更好地利用共享内存并行性的内核,我们需要重新排序循环并改变内核的矢量化方式。不幸的是,Pallas中没有适合这种操作的程序表示。一个潜在的方向是探索一种不同的表示方式,也许类似于Dex中的表示方式,来自动高效地微分内核。我们也可以看看Enzyme是如何处理这个问题的。然而,Pallas内核的自动微分对于一类能够高效转置的内核(例如逐元素内核)仍然是有用的。

尽管如此,jax.custom_vjp 是一个可行的逃生舱口,用于表达与 jax.grad 一起工作的 Pallas 内核。

其他转换#

我们可以想象其他应用于 Pallas 内核的 JAX 变换,这些变换我们尚未明确探索。例如,checkify 是一个执行功能性错误处理的 JAX 变换。我们可以想象将 checkify 与 pallas_call 结合使用,以便从指示是否发生越界访问或产生 NaN 的 GPU 内核中导出错误代码。

另一个可以整合的潜在转换是 custom_partitioning,以启用可自动分区的内核与 pjit 一起使用。