jax.experimental.pallas.pallas_call#
- jax.experimental.pallas.pallas_call(kernel, out_shape, *, grid_spec=None, grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec, input_output_aliases={}, debug=False, interpret=False, name=None, compiler_params=None, cost_estimate=None)[源代码][源代码]#
在某些输入上调用 Pallas 内核。
参见 Pallas 快速入门。
- 参数:
kernel (Callable[..., None]) – 内核函数,接收每个输入和输出的 Ref。Ref 的形状由相应
in_specs
和out_specs
中的block_shape
给出。out_shape (Any) – 一个
jax.ShapeDtypeStruct
的 PyTree,描述了输出的形状和数据类型。grid_spec (GridSpec | None) – 指定
grid
、in_specs
和out_specs
的另一种方法。如果给出这些参数,则不得同时给出其他参数。grid (TupleGrid) – 迭代空间,作为一个整数元组。内核执行的次数与
prod(grid)
相同。详情请参见 grid,即循环中的内核。in_specs (BlockSpecTree) – 一个与位置参数结构匹配的
jax.experimental.pallas.BlockSpec
的 PyTree。in_specs
的默认值指定所有输入的整个数组,例如,pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)
。详情请参见 BlockSpec,又名如何分割输入。out_specs (BlockSpecTree) – 一个与输出结构匹配的
jax.experimental.pallas.BlockSpec
的 PyTree。out_specs
的默认值指定整个数组,例如pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)
。详情请参见 BlockSpec,又名如何分割输入。input_output_aliases (dict[int, int]) – 一个字典,将某些输入的索引映射到别名它们的输出的索引。这些索引在展平的输入和输出中。
debug (bool) – 如果为真,Pallas 会在处理内核时打印各种中间形式。
interpret (bool) – 将
pallas_call
作为jax.jit
运行,该jax.jit
是对网格的扫描,其主体是作为 JAX 函数降低的内核。这不需要 TPU 或 GPU,并且是唯一在 CPU 上运行 Pallas 内核的方法。这对于调试很有用。name (str | None) – 如果存在,指定在调试和错误消息中用于此内核调用的名称。我们将文件和定义内核函数的行附加到此名称上,例如:{name} for kernel function {kernel_name} at {file}:{line}。如果不存在,则使用 {kernel_name} at {file}:{line}。
compiler_params (dict[str, Any] | pallas_core.CompilerParams | None) – 可选的编译器参数。如果提供了一个字典,它应该是 {平台: {参数名: 参数值}} 的形式,其中平台可以是 ‘mosaic’ 或 ‘triton’。也可以传入 jax.experimental.pallas.tpu.TPUCompilerParams 用于 TPU 和 jax.experimental.pallas.gpu.TritonCompilerParams 用于 Triton/GPU。
cost_estimate (CostEstimate | None)
- 返回:
一个可以在多个位置数组参数上调用以调用Pallas内核的函数。
- 返回类型:
Callable[…, Any]