Pallas 更新日志#
这是特定于 jax.experimental.pallas
的变更列表。有关 JAX 的整体变更日志,请参见 此处。
随 jax 0.4.32 发布#
更改
内核函数不允许关闭常量。相反,所有需要的数组必须作为输入传递,并带有适当的块规格(#22746)。
弃用
新功能:
改进了索引映射函数签名中错误的错误消息,以包括索引映射的名称和源位置。
随 jax 0.4.31 版本发布(2024年7月29日)#
更改
jax.experimental.pallas.BlockSpec
现在期望block_shape
在index_map
之前 传递。旧的参数顺序已被弃用,并将在未来的版本中移除。jax.experimental.pallas.GridSpec
不再有in_specs_tree
和out_specs_tree
字段,并且in_specs
和out_specs
树现在以 BlockSpec 的 pytrees 形式存储值。以前,in_specs
和out_specs
是扁平化的(#22552)。compute_index
方法已从jax.experimental.pallas.GridSpec
中移除,因为它是一个私有方法。同样地,get_grid_mapping
和unzip_dynamic_bounds
已从BlockSpec
中移除(#22593)。修复了解释模式以处理涉及填充的 BlockSpec(#22275)。在解释模式下,填充将使用 NaN,以帮助调试越界错误,但在自定义内核模式下运行时,此行为不存在,不应依赖于此行为。
之前可以导入许多本应为私有的API,如
jax.experimental.pallas.pallas
。现在不再可能了。
弃用
新功能
添加了 BlockSpec 的文档:网格和BlockSpecs。
改进了
jax.experimental.pallas.pallas_call()
API 的错误信息。为 TPU 添加了
lax.shift_right_arithmetic
(#22279) 和lax.erf_inv
(#22310) 的降低规则。为 Pallas TPU 自定义内核添加了对形状多态性的初步支持\ (#22084)。
为 checkify 添加了 TPU 支持。(#22480)
当块大小不符合TPU要求时,增加了更清晰的错误信息。以前,错误信息来自Mosaic后端,并且没有有用的Python堆栈跟踪。
增加了对使用1D块的TPU降低的支持,并放宽了对至少2维块大小的要求:最后两个维度必须分别能被8和128整除,除非它们跨越整个相应的数组维度。以前,只有当最后两个维度的块维度分别小于8和128时,才允许跨越整个数组的块维度。
随 JAX 0.4.30 版本发布(2024年6月18日)#
新功能
在解释模式下为
jax.experimental.pallas.pallas_call()
添加了 checkify 支持(#21862)。改进了对TPU内核的PRNG键的支持(#21773)。