使用 Pallas 编写 TPU 内核#
本页重点介绍在尝试在Google TPU上运行Pallas内核时需要注意的细节。首先,TPU后端仍处于实验阶段,只有一部分JAX NumPy会被接受。此外,为TPU编写高性能代码可能需要仔细考虑硬件的原生能力。虽然许多不符合硬件特性的模式会被接受,但它们最终可能需要软件模拟,并可能减慢计算速度。
警告
此功能仍应被视为实验性的,因为工作仍在进行中(特别是改进错误消息)。
备注
虽然这里描述的所有功能都是实验性的,但我们仍然非常认真地维护它们的正确性。因此,在尝试编写TPU内核时,看到“未实现”错误可能并不罕见。但是,如果一个内核被编译器接受,它*必须*返回预期的结果。
如果你看到意外的输出,请将它们与通过 pallas_call
传递 interpret=True
运行的内核结果进行比较。如果结果不一致,请提交一个 错误报告。
什么是 TPU?#
TPU 是 Google 开发的硬件加速器。你可以将 TPU 视为 GPU,但它们专门用于机器学习工作负载。因此,它们的架构有很大不同。然而,我们相信 Pallas 可以让您轻松开始编写 TPU 内核,即使您对底层硬件没有完全了解。话虽如此,对硬件有深入了解无疑会使编写高性能内核变得更加容易。
简而言之,TPU 和 GPU 的主要区别在于 TPU 是顺序机器,具有非常宽的向量寄存器(有点像 CPU!)。同时,它们允许软件在后台调度某些操作,使这些操作与主指令流异步执行。这包括 HBM 内存访问(不能直接发出,而是必须由 DMA 子单元预取到内存层次结构的较低级别)、矩阵乘法(由 MXU 单元支持)或矩阵转置和置换(由 XLU 单元支持)。
如果你对详细了解TPU架构感兴趣,我们推荐阅读多年来发表的一系列论文。虽然其中许多讨论了特定的TPU代际,但所描述的许多想法也适用于后续代际。
值得注意的属性和限制#
BlockSpec
和网格迭代#
BlockSpec
s(参见 BlockSpec,又名如何分割输入)在 Pallas 中的行为通常符合预期——内核体的每次调用都可以访问输入的切片,并用于初始化输出的切片。
备注
- 并非所有块形状都受支持。在TPU上,仅支持秩至少为1的块。
得到支持。此外,你的块形状的最后两个维度必须分别能被8和128整除,或者等于整个数组的相应维度。
Pallas TPU 内核的一个有趣方面是它们处理内存空间的方式:虽然 pallas_call
的输入通常驻留在 HBM(主 TPU 内存)中,但传递给内核体的引用将指向内存层次结构较低级别(VMEM 或 SMEM)中的缓冲区。这使得内核体能够以非常高的速度写入和读取它们,而所有与 HBM(具有非常高的延迟)的通信都由编译器处理并与计算重叠。
此外,与GPU相比,TPU实际上是高度顺序的机器。因此,网格通常不是并行处理,而是按字典顺序顺序处理(尽管请参阅 多核TPU配置 部分了解例外情况)。这解锁了一些有趣的功能:
当两个(按字典顺序)连续的网格索引使用输入的同一部分时,由于数据已经可用,第二次迭代的HBM传输将被跳过。
内核体的多次调用可以写入输出的同一部分,而不会存在任何竞争条件的风险。然而,我们确实要求所有写入特定部分的调用必须是连续的。
输出上的“连续”限制通常意味着网格维度的一些前缀总是变化,以访问每次调用所需的输出切片,而输出窗口对于剩余的后缀保持不变。
例如,在为矩阵乘法实现Pallas TPU内核时,通常会使用一个三维网格:前两个维度对应于沿左操作数的第一轴和右操作数的第二轴进行切片。第三个和*最后一个*网格轴将平铺减少维度。对应于减少维度的网格轴必须是最后一个,因为输出窗口不沿此轴变化。然后可以使用输出引用作为部分结果的累加器。
备注
VMEM 对于这样一个低层次的内存层次结构来说相当大(16MB+),使得使用大窗口尺寸成为可能。而且,通常情况下,窗口尺寸越大,最终的硬件利用率就越好。然而,有可能指定一个窗口尺寸(连同保存溢出向量寄存器所需的空间)超过VMEM的大小。在这种情况下,你可能会看到一个低层次的编译器错误信息,抱怨内存不足的错误。
维度排序是有意义的#
在 JAX 程序中,jax.jit
内部的中间数组顺序通常对性能没有影响,因为编译器可以自由地重新排列它们。然而,由于 Pallas 旨在暴露较低级别的功能,维度顺序可能会对生成的代码质量产生重大影响。
回想一下,TPU 在 2D 向量寄存器上执行大部分计算。Pallas TPU 只会考虑将中间数组的最后两个维度映射到这些向量寄存器维度(分别是子通道和通道)。一个形状为 (n, 1, 1)
的数组保证至少需要 n
个向量寄存器来表示。如果 n
变得太大,这可能导致溢出,并可能由于内存占用过大而导致 VMEM OOM 错误。但这也不一定——底层编译器可以自由地重新排列指令以降低寄存器压力,实际上它在这方面非常擅长。尽管如此,保持最后两个维度较大(尤其是最后一个维度),同时保持前导维度较小,仍然是一个很好的经验法则。
多核 TPU 配置#
在较新的TPU代系中,芯片上的两个核心通常被抽象为一个单一设备。为了利用多个核心,Pallas必须打破顺序网格执行的保证,并且需要在核心上并行化其中一个网格轴。这是一个可选的过程。为此,pallas_call
需要一个额外的参数名为 dimension_semantics
:
该参数是一个列表,列表中的条目数与网格中的轴数相同。只有 parallel
维度可以在核心上进行分区。根据经验,维度是并行的,除非输出窗口不变化。因此,dimension_semantics
总是由一定数量的 parallel
轴和一定数量的 arbitrary
轴组成。
虽然将内核分区到一个双核TPU设备上通常会导致2倍的加速,但实际上可能会显著更小。如果主体的不同实例具有高度不同的成本,这种情况尤其明显。如果所有昂贵的步骤都被映射到一个核心,而所有廉价的步骤都被分配到另一个核心,那么第二个核心将会闲置,直到第一个核心完成其任务。
Pallas TPU 通常倾向于对尺寸为 TPU 核心数倍数的轴进行分区,并倾向于对前导网格轴进行分区。
在SMEM中放置操作数#
TPU上的大部分计算将在向量单元上进行。然而,在许多情况下,执行一些标量操作是有用的,例如,执行控制流。因此,TPU配备了单独的标量单元和与之相连的单独标量内存(SMEM)。根据经验法则,任何用于执行控制流决策的数据都应放置在SMEM中。
SMEM 是一种支持随机访问的低延迟内存,但它只允许你通过单个指令读写32位值(与4KBi粒度的VMEM事务相比非常小,但由于没有对齐要求,因此更加灵活!)。
标量内存在实现不按常规模式访问输入块的核函数时也非常有用,例如在编写块稀疏核函数时。在 Pallas 中,这可以通过将 pallas_call
的 grid
参数替换为 PrefetchScalarGridSpec
的 grid_spec
,并设置非零的 num_scalar_prefetch
参数来实现。如果 num_scalar_prefetch
为 n
,那么 pallas_call
的前 n
个参数将被放置在 SMEM 中。这些参数不应指定 BlockSpec
。但是,所有后续参数的 BlockSpec
不仅会接收网格索引,还会接收前导操作数的 SMEM 引用。
备注
我们正在为这个功能实现示例。敬请期待!
支持的数据类型#
目前 Pallas TPU 仅支持以下数据类型:
jnp.float32
jnp.bfloat16
jnp.int*
(所有精度,除了jnp.int4
)jnp.uint*
(所有精度)
计算放置#
所有标量(即0D)数组将存储在标量寄存器中,并且对其的操作将在标量核心上执行。所有其他操作(即使是单元素,但1D+数组上的操作)将在向量核心上执行。
支持的操作#
矩阵乘法#
矩阵乘法总是以 float32 格式生成结果。如果你的输入不是 float32,我们建议使用 lax.dot
并将 preferred_element_type
设置为 jnp.float32
。
在使用 lax.dot_general
时,可以将矩阵乘法操作数的最后两个维度的转置融合到操作中,这可以提高整体内核性能。
精度控制#
Pallas TPU lowering 知道 jax.default_matmul_precision
。为了获得最佳性能(和最低精度),请使用 bfloat16
。如果您关心数值精度,您可能希望将精度设置为 float32
。
警告
即使你传入32位操作数进行矩阵乘法,它们也会被四舍五入为``bfloat16``,除非请求``float32``精度。
转置#
如果值至少有4个维度,则除了最后两个轴之外的所有轴的任意转置都是自由的。否则,仅实现最后两个轴的转置。请注意,最后两个维度的某些转置可以融合到矩阵乘法中。
访问内存#
引用的一部分可以被读取或更新,这取决于实现的限制。目前,对于32位宽的输入没有限制,但对于较窄的类型,仅支持某些切片模式。在最后两个维度中,对齐到8和128的倍数且长度分别为8和128的倍数的读取和写入总是受支持的。
向量内存的读写通常发生在形状为 (8, 128)
的块上。因此,当读取或写入至少有两个维度的引用时,当内存访问的基本偏移量的索引能被块大小整除,且读取区域的大小是块大小的倍数时,可以获得最佳性能。
逐元素操作#
许多元素级操作都得到支持。值得注意的是,硬件通常只支持使用32位类型的元素级计算。当加载使用较低精度类型的操作数时,通常应先将它们向上转换为32位类型,然后再应用元素级操作。
值得注意的是,它们的成本可能会有 显著 差异。因此,我们概述了三种支持的操作类别:廉价(🟢)、中等(🌕)和昂贵(🔴)。
操作 |
成本 |
---|---|
|
🟢 |
|
🟢 |
|
🟢 |
|
🌕 |
|
🟢 |
|
🟢 |
|
🟢 |
|
🟢 |
|
🟢 |
比较 ( |
🟢 |
类型转换 ( |
🟢 |
|
🌕 |
|
🌕 |
|
🌕 |
|
🔴 |
|
🔴 |
许多 JAX 函数是基于其他 JAX 原语实现的,因此这个列表可能并不全面。例如,jax.nn.relu
是基于比较和 jnp.where
实现的,这些也可以在 Pallas 内核中工作。
数组构造器#
所有常量数组构造器都支持(jnp.ones
、jnp.zeros
、jnp.full
)。值得注意的是,截至今天,jax.random
模块与 Pallas **不**兼容。
简化#
支持求和、最大值和最小值的归约操作,但一次只能在一个数组的轴上进行。
对最后一个数组维度的缩减通常是最慢的。对倒数第二个维度的缩减速度较快,但仍比对前导维度的缩减慢。
广播#
广播的性能特征与归约非常相似。除了最后两个维度之外,沿所有维度的广播总是被支持且免费的。沿倒数第二个维度的广播较慢,而沿最后一个维度的广播最慢。
重塑#
像往常一样,除了最后两个维度之外的所有维度的重塑都是支持的,并且是免费的。
当一个重塑操作可以修改数组的最后两个维度时,仅支持以下两种情况:(1) 一些前导维度被展平到倒数第二个维度,或 (2) 它添加了一个刚刚被减少操作移除的维度。
控制流#
TPU 后端目前对控制流的支持有限。当前支持的函数是 cond
、fori_loop
和 for_loop
。然而,循环原语在编译时会完全展开,因此请尽量保持循环次数合理小。
过度使用控制流可能会导致低级代码生成出现显著的退化,建议尽可能将多个计算密集型操作压缩到单个基本块中。