jax.experimental.pallas 模块

目录

jax.experimental.pallas 模块#

用于Pallas的模块,Pallas是JAX的扩展,用于自定义内核。

请参阅 Pallas 文档,网址为 https://jax.readthedocs.io/en/latest/pallas.html

#

BlockSpec([block_shape, index_map, ...])

指定如何为每次内核调用对数组进行切片。

GridSpec([grid, in_specs, out_specs])

jax.experimental.pallas.pallas_call() 编码网格参数。

Slice(start, size[, stride])

一个带有起始索引和大小的切片。

函数#

pallas_call(kernel, out_shape, *[, ...])

在某些输入上调用 Pallas 内核。

program_id(axis)

返回网格沿给定轴的内核执行位置。

num_programs(axis)

返回网格沿给定轴的大小。

load(x_ref_or_view, idx, *[, mask, other, ...])

从给定的索引加载数组并返回。

store(x_ref_or_view, idx, val, *[, mask, ...])

在给定的索引处存储一个值。

swap(x_ref_or_view, idx, val, *[, mask, ...])

交换给定索引处的值并返回旧值。

atomic_and(x_ref_or_view, idx, val, *[, mask])

原子性地计算 x_ref_or_view[idx] &= val

atomic_add(x_ref_or_view, idx, val, *[, mask])

原子性地计算 x_ref_or_view[idx] += val

atomic_cas(ref, cmp, val)

在引用中执行原子的比较和交换操作,将其值替换为给定的值。

atomic_max(x_ref_or_view, idx, val, *[, mask])

原子性地计算 x_ref_or_view[idx] = max(x_ref_or_view[idx], val)

atomic_min(x_ref_or_view, idx, val, *[, mask])

原子性地计算 x_ref_or_view[idx] = min(x_ref_or_view[idx], val)

atomic_or(x_ref_or_view, idx, val, *[, mask])

原子性地计算 x_ref_or_view[idx] |= val

atomic_xchg(x_ref_or_view, idx, val, *[, mask])

原子性地将给定值与给定索引处的值进行交换。

debug_print(fmt, *args)

从 Pallas 内核内部打印标量值。