网格和BlockSpecs#
grid
,即循环中的内核#
在使用 jax.experimental.pallas.pallas_call()
时,内核函数会根据 pallas_call
的 grid
参数指定的不同输入执行多次。从概念上讲:
pl.pallas_call(some_kernel, grid=(n,))(...)
映射到
for i in range(n):
some_kernel(...)
网格可以被泛化为多维的,对应于嵌套循环。例如,
pl.pallas_call(some_kernel, grid=(n, m))(...)
相当于
for i in range(n):
for j in range(m):
some_kernel(...)
这可以推广到任何整数元组(长度为 d
的网格将对应于 d
个嵌套循环)。内核执行的次数与 prod(grid)
相同。默认的网格值 ()
导致内核调用一次。这些调用中的每一个都被称为一个“程序”。要访问内核当前正在执行的程序(即网格的哪个元素),我们使用 jax.experimental.pallas.program_id()
。例如,对于调用 (1, 2)
,program_id(axis=0)
返回 1
,program_id(axis=1)
返回 2
。你也可以使用 jax.experimental.pallas.num_programs()
来获取给定轴的网格大小。
这是一个使用 grid
和 program_id
的内核实例。
>>> import jax
>>> from jax.experimental import pallas as pl
>>> def iota_kernel(o_ref):
... i = pl.program_id(0)
... o_ref[i] = i
我们现在使用 pallas_call
执行它,并附加一个 grid
参数。
>>> def iota(size: int):
... return pl.pallas_call(iota_kernel,
... out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
... grid=(size,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)
在GPU上,每个程序在独立的线程块中并行执行。因此,我们需要考虑对HBM写入时的竞争条件。一个合理的方法是以这样的方式编写内核,即不同的程序写入HBM中不相交的位置,以避免这些并行写入。
在TPU上,程序是结合并行和顺序(取决于架构)执行的,因此有一些不同的考虑因素。请参阅 Pallas TPU文档。
BlockSpec
,又名如何分割输入#
结合 grid
参数,我们需要向 Pallas 提供每次调用时如何分割输入的信息。具体来说,我们需要提供一个映射,将 循环的迭代 映射到 要操作的输入和输出的哪个块。这是通过 jax.experimental.pallas.BlockSpec
对象提供的。
在我们深入了解 BlockSpec
的细节之前,你可能想重新访问 Pallas 快速入门 BlockSpecs 示例。
BlockSpec
通过 in_specs
和 out_specs
提供给 pallas_call
,每个输入和输出分别对应一个。
首先,我们讨论当 indexing_mode == pl.Blocked()
时 BlockSpec
的语义。
非正式地,BlockSpec
的 index_map
接受调用索引(与 grid
元组的长度相同)作为参数,并返回 块索引(每个整体数组的每个轴对应一个块索引)。然后,每个块索引乘以 block_shape
中相应的轴大小,以获得相应数组轴上的实际元素索引。
备注
并非所有块形状都受支持。
在TPU上,仅支持秩至少为1的块。此外,块形状的最后两个维度必须等于整个数组的相应维度,或者分别能被8和128整除。对于秩为1的块,块维度必须等于数组维度,或者能被
128 * (32 / bitwidth(dtype))
整除。在GPU上,块的大小本身不受限制,但每个操作必须对大小为2的幂的数组进行操作。
如果块形状不能均匀地分割整体形状,那么每个轴上的最后一次迭代仍然会接收到 block_shape
块的引用,但越界的元素在输入时会被填充,在输出时会被丢弃。填充的值是未指定的,你应该假设它们是垃圾数据。在 interpret=True
模式下,我们用 NaN 填充浮点值,以给用户一个发现访问越界元素的机会,但这种行为不应依赖。请注意,每个块中至少有一个元素必须在边界内。
更准确地说,输入 x
的形状 x_shape
的每个轴的切片是按照下面的函数 slice_for_invocation
计算的:
>>> def slices_for_invocation(x_shape: tuple[int, ...],
... x_spec: pl.BlockSpec,
... grid: tuple[int, ...],
... invocation_indices: tuple[int, ...]) -> tuple[slice, ...]:
... assert len(invocation_indices) == len(grid)
... assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
... block_indices = x_spec.index_map(*invocation_indices)
... assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
... elem_indices = []
... for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
... start_idx = block_idx * block_size
... # At least one element of the block must be within bounds
... assert start_idx < x_size
... elem_indices.append(slice(start_idx, start_idx + block_size))
... return elem_indices
例如:
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
... grid = (10, 5),
... invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]
>>> # Same shape of the array and blocks, but we iterate over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
... x_spec = pl.BlockSpec((10, 20), lambda i, j, k: (i, j)),
... grid = (10, 5, 4),
... invocation_indices = (2, 4, 0))
[slice(20, 30, None), slice(80, 100, None)]
>>> # An example when the block is partially out-of-bounds in the 2nd axis.
>>> slices_for_invocation(x_shape=(100, 90),
... x_spec = pl.BlockSpec((10, 20), lambda i, j: (i, j)),
... grid = (10, 5),
... invocation_indices = (2, 4))
[slice(20, 30, None), slice(80, 100, None)]
下面定义的函数 show_program_ids
使用 Pallas 来显示调用索引。iota_2D_kernel
将用一个十进制数填充每个输出块,其中第一个数字表示沿第一个轴的调用索引,第二个数字表示沿第二个轴的调用索引:
>>> def show_program_ids(x_shape, block_shape, grid,
... index_map=lambda i, j: (i, j),
... indexing_mode=pl.Blocked()):
... def program_ids_kernel(o_ref): # Fill the output block with 10*program_id(1) + program_id(0)
... axes = 0
... for axis in range(len(grid)):
... axes += pl.program_id(axis) * 10**(len(grid) - 1 - axis)
... o_ref[...] = jnp.full(o_ref.shape, axes)
... res = pl.pallas_call(program_ids_kernel,
... out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
... grid=grid,
... in_specs=[],
... out_specs=pl.BlockSpec(block_shape, index_map, indexing_mode=indexing_mode),
... interpret=True)()
... print(res)
例如:
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
... index_map=lambda i, j: (i, j))
[[ 0 0 0 1 1 1]
[ 0 0 0 1 1 1]
[10 10 10 11 11 11]
[10 10 10 11 11 11]
[20 20 20 21 21 21]
[20 20 20 21 21 21]
[30 30 30 31 31 31]
[30 30 30 31 31 31]]
>>> # An example with out-of-bounds accesses
>>> show_program_ids(x_shape=(7, 5), block_shape=(2, 3), grid=(4, 2),
... index_map=lambda i, j: (i, j))
[[ 0 0 0 1 1]
[ 0 0 0 1 1]
[10 10 10 11 11]
[10 10 10 11 11]
[20 20 20 21 21]
[20 20 20 21 21]
[30 30 30 31 31]]
>>> # It is allowed for the shape to be smaller than block_shape
>>> show_program_ids(x_shape=(1, 2), block_shape=(2, 3), grid=(1, 1),
... index_map=lambda i, j: (i, j))
[[0 0]]
当多次调用写入输出数组的相同元素时,结果取决于平台。
在下面的示例中,我们有一个3D网格,最后一个网格维度在块选择中未被使用(index_map=lambda i, j, k: (i, j)
)。因此,我们迭代了相同的输出块10次。下面显示的输出是在CPU上使用interpret=True
模式生成的,该模式目前按顺序执行调用。在TPU上,程序以并行和顺序的组合方式执行,此函数生成所示的输出。请参阅Pallas TPU文档。
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2, 10),
... index_map=lambda i, j, k: (i, j))
[[ 9 9 9 19 19 19]
[ 9 9 9 19 19 19]
[109 109 109 119 119 119]
[109 109 109 119 119 119]
[209 209 209 219 219 219]
[209 209 209 219 219 219]
[309 309 309 319 319 319]
[309 309 309 319 319 319]]
在 block_shape
中出现的 None
值作为维度值时,其行为类似于值 1
,除了相应的块轴被压缩。在下面的例子中,观察到当块形状被指定为 (None, 2)
时,o_ref
的形状是 (2,)(前导维度被压缩)。
>>> def kernel(o_ref):
... assert o_ref.shape == (2,)
... o_ref[...] = jnp.full((2,), 10 * pl.program_id(1) + pl.program_id(0))
>>> pl.pallas_call(kernel,
... jax.ShapeDtypeStruct((3, 4), dtype=np.int32),
... out_specs=pl.BlockSpec((None, 2), lambda i, j: (i, j)),
... grid=(3, 2), interpret=True)()
Array([[ 0, 0, 10, 10],
[ 1, 1, 11, 11],
[ 2, 2, 12, 12]], dtype=int32)
当我们构建一个 BlockSpec
时,我们可以为 block_shape
参数使用值 None
,在这种情况下,使用整个数组的形状作为 block_shape
。如果我们为 index_map
参数使用值 None
,那么将使用一个默认的索引映射函数,该函数返回一个零元组:index_map=lambda *invocation_indices: (0,) * len(block_shape)
。
>>> show_program_ids(x_shape=(4, 4), block_shape=None, grid=(2, 3),
... index_map=None)
[[12 12 12 12]
[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]
>>> show_program_ids(x_shape=(4, 4), block_shape=(4, 4), grid=(2, 3),
... index_map=None)
[[12 12 12 12]
[12 12 12 12]
[12 12 12 12]
[12 12 12 12]]
“未阻塞”索引模式#
上述行为适用于 indexing_mode=pl.Blocked()
。当使用 pl.Unblocked
索引模式时,索引映射函数返回的值直接用作数组索引,而不首先按块大小进行缩放。在使用非阻塞模式时,您可以为数组指定虚拟填充,作为每个维度的低-高填充元组:行为就像在输入时对整个数组进行了填充。对于非阻塞模式中的填充值不做保证,类似于阻塞索引模式中当块形状不分割整个数组形状时的填充值。
当前仅在 TPU 上支持未阻塞模式。
>>> # unblocked without padding
>>> show_program_ids(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2),
... index_map=lambda i, j: (2*i, 3*j),
... indexing_mode=pl.Unblocked())
[[ 0 0 0 1 1 1]
[ 0 0 0 1 1 1]
[10 10 10 11 11 11]
[10 10 10 11 11 11]
[20 20 20 21 21 21]
[20 20 20 21 21 21]
[30 30 30 31 31 31]
[30 30 30 31 31 31]]
>>> # unblocked, first pad the array with 1 row and 2 columns.
>>> show_program_ids(x_shape=(7, 7), block_shape=(2, 3), grid=(4, 3),
... index_map=lambda i, j: (2*i, 3*j),
... indexing_mode=pl.Unblocked(((1, 0), (2, 0))))
[[ 0 1 1 1 2 2 2]
[10 11 11 11 12 12 12]
[10 11 11 11 12 12 12]
[20 21 21 21 22 22 22]
[20 21 21 21 22 22 22]
[30 31 31 31 32 32 32]
[30 31 31 31 32 32 32]]