Pallas 更新日志#

这是特定于 jax.experimental.pallas 的变更列表。有关 JAX 的整体变更日志,请参见 此处

随 jax 0.4.32 发布#

  • 更改

    • 内核函数不允许关闭常量。相反,所有需要的数组必须作为输入传递,并带有适当的块规格(#22746)。

  • 弃用

  • 新功能:

    • 改进了索引映射函数签名中错误的错误消息,以包括索引映射的名称和源位置。

随 jax 0.4.31 版本发布(2024年7月29日)#

  • 更改

    • jax.experimental.pallas.BlockSpec 现在期望 block_shapeindex_map 之前 传递。旧的参数顺序已被弃用,并将在未来的版本中移除。

    • jax.experimental.pallas.GridSpec 不再有 in_specs_treeout_specs_tree 字段,并且 in_specsout_specs 树现在以 BlockSpec 的 pytrees 形式存储值。以前,in_specsout_specs 是扁平化的(#22552)。

    • compute_index 方法已从 jax.experimental.pallas.GridSpec 中移除,因为它是一个私有方法。同样地,get_grid_mappingunzip_dynamic_bounds 已从 BlockSpec 中移除(#22593)。

    • 修复了解释模式以处理涉及填充的 BlockSpec(#22275)。在解释模式下,填充将使用 NaN,以帮助调试越界错误,但在自定义内核模式下运行时,此行为不存在,不应依赖于此行为。

    • 之前可以导入许多本应为私有的API,如 jax.experimental.pallas.pallas。现在不再可能了。

  • 弃用

  • 新功能

    • 添加了 BlockSpec 的文档:网格和BlockSpecs

    • 改进了 jax.experimental.pallas.pallas_call() API 的错误信息。

    • 为 TPU 添加了 lax.shift_right_arithmetic (#22279) 和 lax.erf_inv (#22310) 的降低规则。

    • 为 Pallas TPU 自定义内核添加了对形状多态性的初步支持\ (#22084)。

    • 为 checkify 添加了 TPU 支持。(#22480)

    • 当块大小不符合TPU要求时,增加了更清晰的错误信息。以前,错误信息来自Mosaic后端,并且没有有用的Python堆栈跟踪。

    • 增加了对使用1D块的TPU降低的支持,并放宽了对至少2维块大小的要求:最后两个维度必须分别能被8和128整除,除非它们跨越整个相应的数组维度。以前,只有当最后两个维度的块维度分别小于8和128时,才允许跨越整个数组的块维度。

随 JAX 0.4.30 版本发布(2024年6月18日)#