jax.experimental.pallas
模块#
用于Pallas的模块,Pallas是JAX的扩展,用于自定义内核。
请参阅 Pallas 文档,网址为 https://jax.readthedocs.io/en/latest/pallas.html。
类#
|
指定如何为每次内核调用对数组进行切片。 |
|
为 |
|
一个带有起始索引和大小的切片。 |
函数#
|
在某些输入上调用 Pallas 内核。 |
|
返回网格沿给定轴的内核执行位置。 |
|
返回网格沿给定轴的大小。 |
|
从给定的索引加载数组并返回。 |
|
在给定的索引处存储一个值。 |
|
交换给定索引处的值并返回旧值。 |
|
原子性地计算 |
|
原子性地计算 |
|
在引用中执行原子的比较和交换操作,将其值替换为给定的值。 |
|
原子性地计算 |
|
原子性地计算 |
|
原子性地计算 |
|
原子性地将给定值与给定索引处的值进行交换。 |
|
从 Pallas 内核内部打印标量值。 |