矩阵乘法#

在本指南中,我们将使用 Pallas 编写一个矩阵乘法例程。我们还将讨论如何考虑 TPU 上的 matmul 性能,以及如何将 matmul 核心模板化以融合操作。

#@标题 导入
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np

背景#

矩阵乘法是现代深度学习和语言建模的核心基础线性代数运算。我们希望使用专门的加速器,如TPU和GPU,使矩阵乘法尽可能快速,这两者都有用于快速矩阵乘法的专用单元。

为了有效利用TPU进行矩阵乘法,我们需要了解几个背景概念:块矩阵乘法、切片和流水线技术。

块矩阵乘法#

假设我们想实现matmul(x, y),它通用地将一个(m, k)数组与一个(k, n)数组相乘,但有一个限制。我们只能使用基本的matmul_small,它对小矩阵进行乘法运算(比如m, k, n <= 256)。我们该如何操作呢?

矩阵乘法的一个良好特性是,每个输出块可以表示为几个较小矩阵乘法的行块和列块的和。 形式上,如果我们有输入数组 \(x \in \mathbb{R}^{m \times k}\)\(y \in \mathbb{R}^{k \times n}\),以及输出 \(z \in \mathbb{R}^{m \times n}\),我们将它们沿着大小为 \(b_m, b_k, b_n\) 的维度分解成块。

例如,\(x\) 可以被分解为:

\[\begin{split} \begin{bmatrix} x_{0, 0} & \cdots & x_{0, i_k} \\ x_{1, 0} & \cdots & x_{1, i_k} \\ \vdots & \ddots & \vdots \\ x_{i_m, 0} & \cdots & x_{i_m, i_k} \\ \end{bmatrix} \end{split}\]

其中 \(x_{ik} \in \mathbb{R}^{b_m \times b_k}\)。 (我们可以类似地分解 \(y\)\(z\)。)

对于特定的输出块 \(z_{ij}\),我们可以计算它为

\[ z_{ij} = \sum_k x_{ik} y_{kj} \]

因此,每个输出块 \(z_{ij}\) 是几个较小块矩阵乘法 \(x_{ik} y_{kj}\) 的总和。以下是我们在NumPy中实现此算法的方式:

def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  m, k, n = x.shape[0], x.shape[1], y.shape[0]
  assert m <= 256
  assert k <= 256
  assert n <= 256
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  m, k = x.shape
  _, n = y.shape

  z = np.zeros((m, n), dtype=x.dtype)
  for m_i in range(m // bm):
    for n_i in range(n // bn):
      for k_i in range(k // bk):
        m_slice = slice(m_i * bm, (m_i + 1) * bm)
        k_slice = slice(k_i * bk, (k_i + 1) * bk)
        n_slice = slice(n_i * bn, (n_i + 1) * bn)
        x_block = x[m_slice, k_slice]
        y_block = y[k_slice, n_slice]
        z[m_slice, n_slice] += matmul_small(x_block, y_block)
  return z

我们的 block_matmul 函数现在应该可以处理大于256的输入(尽管我们假设我们的输入维度能够均匀地被256整除)。

m, k, n = 4096, 4096, 4096
x = np.random.uniform(size=(m, k)).astype(np.float32)
y = np.random.uniform(size=(k, n)).astype(np.float32)
np.testing.assert_allclose(x @ y, block_matmul(x, y), atol=1e-6, rtol=1e-6)

block_matmul 将矩阵乘法分解为许多更小的乘法,通过观察到每个大小为 (bm, bn) 的输出块可以通过累加几个大小为 (bm, bk) x (bk, bn) 的矩阵乘法来计算。

TPU 和 GPU 就是这样进行矩阵乘法的!它们原生支持类似于 matmul_small 的小矩阵乘法,因此在进行更大的矩阵乘法时,为了利用这些硬件,我们将应用 block_matmul 分解。

瓷砖和流水线#

之前的指南 中,我们介绍了如何在 Pallas 中进行计算的瓷砖化和流水线。为了确保我们的计算单元始终在工作中,并且不会因内存传输而停滞,我们将下一次迭代的内存传输与当前的内存传输重叠。

在 Pallas 中,我们通过 BlockSpecgrid 来指定这一点。请注意,我们在块矩阵乘法算法中已经有一个嵌套的 for 循环。我们可以通过 grid 在 Pallas 中指定这一点。块矩阵乘法中的切片也可以通过 BlockSpec 来指定。

你的第一个矩阵乘法内核#

综合以上内容,这里是一个块矩阵乘法内核的实现,它将内存传输和计算进行流水线处理。我们创建了一个 3 维网格,对应于 NumPy 代码中的 3 层嵌套循环。请注意,尽管 MXU 只能处理小块的乘法,Pallas 会自动处理更大的块并在 MXU 上自动排列。

网格的最后一个维度对应于矩阵乘法的收缩维度,并且是一个归约维度,因此我们需要确保初始化累加器。

def matmul_kernel(x_ref, y_ref, z_ref):
  @pl.when(pl.program_id(2) == 0)
  def _():
    z_ref[...] = jnp.zeros_like(z_ref)

  z_ref[...] += x_ref[...] @ y_ref[...]

def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      in_specs=[pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
                pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
      out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
      grid=(m // bm, n // bn, k // bk),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.float32)
y = random.normal(k2, (k, n), dtype=jnp.float32)
np.testing.assert_array_equal(x @ y, matmul(x, y))

矩阵乘法性能#

让我们来思考如何分析矩阵乘法的性能。当我们考虑矩阵乘法的性能时,通常会关注两个方面:浮点运算的总数量(FLOPs)和内存带宽的使用量。从关于 TPUs 和流水线的指南中可以看到,为了利用 TPU 上高效的计算单元(以及一般的机器学习加速器),我们需要将输入从 HBM 复制到更接近计算单元的 VMEM。这种从 HBM 复制进出所需的时间,而一个高效的内核希望将大部分时间花在实际计算上,而不是等待这些传输。内存带宽衡量的是数据传输的速率。

快速说明:在本指南中,我们将讨论浮点运算,但希望区分 FLOPs 与 FLOP/s。 当我们说“FLOPs”时,我们指的是“浮点运算”,即运算次数。当我们说“FLOP/s”时,我们指的是“每秒浮点运算”,即执行浮点运算的速率。

在一个 (m, k) x (k, n) 的矩阵乘法中,FLOPs 的数量(大约)为 2 * m * k * n。(从技术上讲,它是 n * m * (2k - 1),但对于足够大的 k,我们的近似是足够的。)

进行矩阵乘法所需的最小内存带宽(假设使用 float32)是输入的总大小(复制到 VMEM)加上输出的大小(复制到 HBM)。因此,最小带宽使用量为 (m * k + k * n + m * n) * 4 bytes/float32。如果我们多次重新读取输入,内存使用量可能会更大,这通常是常见的情况。

一个观察是,矩阵乘法的 FLOPs 对其输入是立方的,而最小带宽使用量则是平方的。直观上,这意味着 FLOPs 的增长速度快于带宽使用量,这意味着我们的矩阵乘法越大,相对于复制我们拥有的计算越多。

def matmul_flops(m: int, k: int, n: int):
  return 2 * m * k * n

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  return (m * k + k * n + m * n) * np.dtype(dtype).itemsize

print(matmul_flops(1024, 1024, 1024))
print(matmul_membw(1024, 1024, 1024, jnp.float32))
2147483648
12582912

现在我们可以计算矩阵乘法的总FLOP数和(最小)内存带宽使用情况,让我们看看一个真实的TPU能处理什么。

这个Notebook是在TPU v5e芯片上运行的,所以我们将使用v5e的数据(如果你在运行这个Notebook,你的数据可能会有所不同)。TPU v5e的计算能力为197 TFLOP/s的bf16/f32计算和819 GB/s的内存带宽。通过查看这些数字的比率(称为算术强度),我们可以得到在变为IO瓶颈之前,这个“FLOP / 内存带宽使用”比率可以下降到多低(在TPU v5e上大约为240 FLOP/字节)。

v5e_flops = 197e12
v5e_membw = 819e9
v5e_op_intensity = v5e_flops / v5e_membw  # ~240.5

粗略来说,这些数字告诉我们,一个矩阵乘法的FLOPs应该需要 2 * m * k * n / (197 TFLOP/s) 秒,而从虚拟内存(VMEM)复制数据所需的时间应该是 (m*k + k*n + m*n) * 4字节 / 819GB/s 秒。

def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  flops = matmul_flops(m, k, n)
  membw = matmul_membw(m, k, n, dtype)
  return flops / membw

这个基本计算大致告诉我们我们将多有效地使用我们的MXUs。如果我们的矩阵乘法操作强度低于我们芯片的能力,则我们的计算将是内存受限,即我们的计算单元将处于闲置状态,等待值被传输。如果矩阵乘法强度高于芯片的能力,那么我们将是计算受限

因为矩阵乘法的浮点运算次数与其输入大小的立方成正比,而内存带宽的使用与其输入大小的平方成正比,我们预计随着规模的不断增大,我们将变得计算受限,但这个交叉点是非常重要的!假设我们正在进行一个(1024, 1024) x (1024, 1024)的float32矩阵乘法。

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.float32)} flops/byte")
170.66666666666666 flops/byte

我们的矩阵乘法的浮点运算强度低于我们芯片的能力。这不好!在这种类型的矩阵乘法中,我们很可能会受到内存的限制。然而,如果我们的输入和输出更大呢? 当我们的矩阵乘法足够大时,我们将从受内存限制转变为受计算限制。例如,如果我们有一个矩阵乘法,其中 m = k = n,那么在 TPU v5e 上,当 2m**3 / 12m**2 > 240 或者当 m = k = n > 1440 时,我们将会发生这个转变。

bfloat16 矩阵乘法#

为了使矩阵乘法在TPU上的计算受限,我们还可以对输入和输出使用较小的数据类型。我们之前的例子使用了float32输入和输出,但TPU v5e还支持bfloat16数据类型(16位浮点格式,也称为bf16)进行矩阵乘法。在TPU v5e上,我们将拥有相同的FLOP/s,但会减少一半的内存带宽使用。这使得对较小矩阵的计算受限变得更容易。让我们看看在进行1024 x 1024 x 1024的bf16矩阵乘法时我们的强度是多少:

print(f"{matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)} flops/byte")
341.3333333333333 flops/byte

我们现在有一个计算受限的矩阵乘法!

让我们为我们的矩阵乘法内核添加 bf16 支持。

本地 MXU 的 bf16 矩阵乘法例程接受两个输入的 bf16 矩阵,并以 f32 进行累积。我们将通过将 preferred_element_type=jnp.float32 传递给 jnp.matmul 来触发这个例程。我们还需要一个 f32 的累加器 Ref。然后,我们将在将输出回写到 HBM 之前将其向下转换回 bf16。这样我们就不会丢失任何精度,不做任何额外的转换,同时仍然保留 bf16 的内存带宽节省。

请注意,现在分配临时空间的唯一方法是通过 pltpu.PrefetchScalarGridSpec。暂时不用担心它具体做什么——你现在需要知道的是,它允许你在 VMEM 中分配临时空间。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
):
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
m, k, n = 4096, 4096, 4096
k1, k2 = random.split(random.key(0), 2)
x = random.normal(k1, (m, k), dtype=jnp.bfloat16)
y = random.normal(k2, (k, n), dtype=jnp.bfloat16)
np.testing.assert_array_equal(x @ y, matmul(x, y))

流水线内核的性能#

我们上述关于FLOPs与内存使用的分析适用于较粗粒度的层面,即当我们在观察整个矩阵乘法的大小时。然而,请记住,实际上我们是在流水线处理一个被阻塞的矩阵乘法,这意味着我们有一个循环,在其中我们正在处理较小块的矩阵乘法。

这意味着我们实际上关心的是每个内核实例的FLOPs与内存带宽使用情况,而不是全局的FLOPs与内存带宽使用情况。因此,块大小bmbkbn对性能而言至关重要。即使我们有世界上最大的矩阵,如果我们选择非常小的bmbkbn,我们将受到内存的限制,因为每次调用内核时,FLOPs太少,无法掩盖后台发生的内存传输。

因此,直觉应该是:为了计算限制,使块尽可能大!主要有两个约束条件:

  1. VMEM使用:块越大,我们使用的VMEM就越多。块大到一定程度后,我们会耗尽内存。

  2. 流水线气泡:相对于矩阵大小,块越大,流水线中的循环迭代次数就越少。这将使得流水线开始和结束时的气泡相对于总流水线的大小更大,而这种开销可能是非微不足道的。

在Pallas中获得良好的矩阵乘法性能归结为选择合适的块大小以平衡这个优化问题。实际上,我们通常会在一组候选块大小中进行搜索,分析内核,并选择最佳的一个。

现在,让我们进行一些非常简单的计时实验。我们将使用timeit来测量每个内核运行所需的时间。请注意,这是内核实际运行时间的上限,因为我们正在使用timeit测量Python调度和其他开销。我们将计算通过这种方式获得的FLOP/s数量,并计算与芯片所提供的比较下的利用率百分比,并使用一些合理的块大小来验证我们的直觉。

import timeit

def benchmark(f, ntrials: int = 100):
  def run(*args, **kwargs):
    # 先编译函数
    jax.block_until_ready(f(*args, **kwargs))
    # 时间函数
    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
                           number=ntrials)
    time = result / ntrials
    # print(f"Time: {time}")
    return time
  return run

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func):
  x = jnp.ones((m, k), dtype=dtype)
  y = jnp.ones((k, n), dtype=dtype)
  time = benchmark(mm_func)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00029766598949208854
Matmul FLOP/s:  7214407167121.377
FLOP/s utilization: 3.6621%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011771515250438824
Matmul FLOP/s:  11675553278230.387
FLOP/s utilization: 5.9267%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09183577066054567
Matmul FLOP/s:  11972585626140.668
FLOP/s utilization: 6.0775%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012708659982308746
Matmul FLOP/s:  16897797651282.135
FLOP/s utilization: 8.5776%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.00088908776990138
Matmul FLOP/s:  154584235803001.88
FLOP/s utilization: 78.4692%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006099433819763363
Matmul FLOP/s:  180264539343531.62
FLOP/s utilization: 91.5048%

更大的块大小有很大帮助!在较大的矩阵乘法中,我们获得了相当好的利用率(80-90%),但最小的矩阵乘法似乎很难实现良好的性能。

让我们将其与XLA的矩阵乘法进行比较。我们不期望Pallas比XLA表现得更好,因为XLA在生成矩阵乘法方面非常出色,但希望我们能接近。 通过更仔细的块大小调优(留作未来工作),我们也可以达到XLA的性能。

print("================ XLA matmul ===================")
mm = jnp.matmul
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm)
================ XLA matmul ===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00011943008983507753
Matmul FLOP/s:  17981093801113.996
FLOP/s utilization: 9.1275%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008272899803705514
Matmul FLOP/s:  166131533963991.34
FLOP/s utilization: 84.3307%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006047147869830951
Matmul FLOP/s:  181823175395037.44
FLOP/s utilization: 92.2960%

Pallas,通过一些非常基础的调优,性能数字与XLA相差无几!通过尝试更多的块大小,我们应该能完全缩小这个差距。

矩阵乘法的模板化#

现在我们拥有了一个基本的矩阵乘法内核,我们可以尝试将操作融合到其中。

融合右侧转置#

常见的第一步是融合转置。我们是什么意思呢?假设我们想计算 x @ y.T 而不是 x @ y。简单来说,我们可以首先计算 y.T,然后将其传递给我们的高效矩阵乘法内核。然而,操作 y.T 本身并不是免费的——它涉及复制 O(n^2) 的数据。理想情况下,我们可以在执行矩阵乘法时同时计算转置,即将其与矩阵乘法“融合”在一起。

加速器通常支持本地矩阵乘法例程,可以融合右侧的转置。例如,在 TPU v5e 中,MXU 允许我们对小数组执行 x @ y.T。我们可以通过 jax.lax.dot_general 调用此例程,这比先进行转置再进行矩阵乘法更高效。

def matmul_kernel(x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  # dot_general 期望一个数据结构(contraction_dims, batch_dims),
  # 其中,`contraction_dims` 是指 LHS 和 RHS 中将被
  # 在 matmul 中会被缩减;而 batch_dims 则相反,
  # 循环遍历。剩余的维度将是输入和输出维度
  # 矩阵乘法。
  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...], y_ref[...], dims, preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = acc_ref[...].astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'transpose_rhs'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(matmul_kernel, nsteps=k // bk, transpose_rhs=transpose_rhs),
      grid_spec=pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
            y_block_spec,
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
        scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
        grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)

我们在 matmul 函数内部进行转置 (y = y.swapaxes(0, 1))。这是因为在 JIT 编译的 JAX 计算中,维度的顺序是纯粹的逻辑,而不是物理的,因此重新排列维度并不意味着物理布局的差异。然而,当我们将一个数组传入 pallas_call 时,我们确实强制实施主到次的维度顺序约束。通过在 matmul 函数内部转置 y,我们要求 y 以转置布局 (n, k) 而不是通常的 (k, n)。然而,用户仍将以逻辑 (n, k) 维度传入数组。

注意:为了对转置进行基准测试,我们实际上希望 y 在传入内核时处于物理转置布局,因此我们不测量重新布局的时间。在包装函数中,我们会将其(逻辑上)转置回 (n, k),然后再传入 matmul,因为 matmul 期望逻辑 (n, k) 的维度顺序。

def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = mm_func
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()

print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, transpose_rhs=True)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, transpose_rhs=True)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.0003029372810851783
Matmul FLOP/s:  7088872126624.065
FLOP/s utilization: 3.5984%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.012017967159627005
Matmul FLOP/s:  11436123235026.848
FLOP/s utilization: 5.8051%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09500920018996112
Matmul FLOP/s:  11572685861765.383
FLOP/s utilization: 5.8745%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012131539988331496
Matmul FLOP/s:  17701657415839.363
FLOP/s utilization: 8.9856%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.0008790623804088682
Matmul FLOP/s:  156347213275211.03
FLOP/s utilization: 79.3641%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006107717020204291
Matmul FLOP/s:  180020067095253.78
FLOP/s utilization: 91.3807%

看看我们如何在额外的转置下仍然得到相同的利用率!

融合激活函数#

在激活函数中进行融合也是非常常见的。这确保我们不在一个高效的、计算密集型的矩阵乘法内核后面跟随一个缓慢的、内存密集型的激活内核。

def matmul_kernel(
    x_ref, y_ref, z_ref, acc_ref, *, nsteps, transpose_rhs, activation
):
  @pl.when(pl.program_id(2) == 0)
  def _():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  if transpose_rhs:
    dims = ((1,), (1,)), ((), ())
  else:
    dims = ((1,), (0,)), ((), ())

  acc_ref[...] += jax.lax.dot_general(
      x_ref[...],
      y_ref[...],
      dims,
      preferred_element_type=jnp.float32,
  )

  @pl.when(pl.program_id(2) == nsteps - 1)
  def _():
    z_ref[...] = activation(acc_ref[...]).astype(z_ref.dtype)


@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'activation'])
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bk: int = 128,
    bn: int = 128,
    transpose_rhs: bool = False,
    activation: Callable[[jax.Array], jax.Array] = lambda x: x,
):
  if transpose_rhs:
    y = y.swapaxes(0, 1)
    y_block_spec = pl.BlockSpec((bn, bk), lambda i, j, k: (j, k))
  else:
    y_block_spec = pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))
  m, k = x.shape
  _, n = y.shape
  return pl.pallas_call(
      functools.partial(
          matmul_kernel,
          nsteps=k // bk,
          transpose_rhs=transpose_rhs,
          activation=activation,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
              y_block_spec,
          ],
          out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
          scratch_shapes=[pltpu.VMEM((bm, bn), jnp.float32)],
          grid=(m // bm, n // bn, k // bk),
      ),
      out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
      compiler_params=pltpu.TPUCompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
  )(x, y)
def analyze_matmul(m: int, k: int, n: int, dtype: np.dtype,
                   mm_func, transpose_rhs: bool = False,
                   activation = lambda x: x):
  x = jnp.ones((m, k), dtype=dtype)
  if transpose_rhs:
    y = jnp.ones((n, k), dtype=dtype)
    @jax.jit
    def _wrapper(x, y):
      y = y.swapaxes(0, 1)
      return mm_func(x, y, transpose_rhs=True, activation=activation)
  else:
    y = jnp.ones((k, n), dtype=dtype)
    _wrapper = functools.partial(mm_func, activation=activation)
  time = benchmark(_wrapper)(x, y)
  print(f"----- {m} x {k} x {n} -----")
  print("Matmul time: ", time)
  mm_flops = matmul_flops(m, k, n) / time
  print("Matmul FLOP/s: ", mm_flops)
  print(f"FLOP/s utilization: {mm_flops / v5e_flops * 100:.4f}%")
  print()


activation = jax.nn.relu
print("================bm=128, bk=128, bn=128===================")
mm = functools.partial(matmul, bm=128, bk=128, bn=128)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)

print("================bm=512, bk=1024, bn=1024===================")
mm = functools.partial(matmul, bm=512, bk=1024, bn=1024)
analyze_matmul(1024, 1024, 1024, jnp.bfloat16, mm, activation=activation)
analyze_matmul(4096, 4096, 4096, jnp.bfloat16, mm, activation=activation)
analyze_matmul(8192, 8192, 8192, jnp.bfloat16, mm, activation=activation)
================bm=128, bk=128, bn=128===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00030103540048003196
Matmul FLOP/s:  7133658182976.541
FLOP/s utilization: 3.6211%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.011807117109419778
Matmul FLOP/s:  11640348122095.826
FLOP/s utilization: 5.9088%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.09181861146935262
Matmul FLOP/s:  11974823079773.941
FLOP/s utilization: 6.0786%

================bm=512, bk=1024, bn=1024===================
----- 1024 x 1024 x 1024 -----
Matmul time:  0.00012622540001757442
Matmul FLOP/s:  17013086492108.6
FLOP/s utilization: 8.6361%

----- 4096 x 4096 x 4096 -----
Matmul time:  0.000896632740041241
Matmul FLOP/s:  153283442968721.44
FLOP/s utilization: 77.8089%

----- 8192 x 8192 x 8192 -----
Matmul time:  0.006130605939542875
Matmul FLOP/s:  179347953304919.88
FLOP/s utilization: 91.0396%

额外的融合激活几乎对我们的利用率没有影响!

结论#

在本指南中,我们介绍了如何使用Pallas在TPU上编写高效的矩阵乘法。我们讨论了分块矩阵乘法和流水线技术,如何分析TPU矩阵乘法的性能,以及如何编写高效的bf16矩阵乘法。最后,我们总结了矩阵乘法的模板,以支持融合转置和融合激活函数。

留给读者的练习:

  • 添加输入融合的支持。有时我们希望将一个操作融合到矩阵乘法的输入中。尝试进一步模板化矩阵乘法以支持这一点。

  • 添加对int8矩阵乘法的支持。TPU v5原生支持int8矩阵乘法,其FLOPs是bf16的两倍。尝试添加对此的支持,并查看可能的使用率。

  • matmul函数添加反向传播支持。您可以使用jax.custom_vjp来实现这一点。