Jax 和 Jaxlib 版本控制#

为什么 jaxjaxlib 是分开的包?#

我们以两个独立的Python轮子发布JAX,即 jax,这是一个纯Python轮子,以及 jaxlib,这是一个主要由C++编写的轮子,包含以下库:

  • XLA,

  • XLA 使用的 LLVM 组件,

  • MLIR 基础设施,例如 StableHLO Python 绑定。

  • 用于快速JIT和PyTree操作的JAX特定C++库。

我们分发独立的 jaxjaxlib 包,因为这使得在不构建 C++ 代码或甚至不安装 C++ 工具链的情况下,也能轻松处理 JAX 的 Python 部分。jaxlib 是一个大型库,对许多用户来说不容易构建,但大多数对 JAX 的更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分更新,我们提高了 Python 更改的开发速度。

此外,jaxlib 的构建成本不低,但我们希望能够在 CPU 资源不多的环境中迭代和运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建只是使用预构建的 jaxlib,而不是需要在每个 PR 上重新构建 JAX 的 C++ 部分。

正如我们将看到的,单独分发 jaxjaxlib 是有代价的,因为这要求对 jaxlib 的更改保持向后兼容的 API。然而,我们认为总体上更容易进行 Python 更改是可取的,即使这使得 C++ 更改稍微困难一些。

jaxjaxlib 是如何版本化的?#

摘要:在 JAX 源码树中,jaxjaxlib 共享相同的版本号,但作为独立的 Python 包发布。安装时,jax 包的版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最小 jaxlib 版本。

jaxjaxlib 的发布版本号均为 x.y.z,其中 x 是主版本号,y 是次版本号,z 是可选的补丁发布版本号。版本号必须遵循 PEP 440。版本号的比较是基于整数元组的字典序比较。

每个 jax 版本都有一个关联的最小 jaxlib 版本 mx.my.mzjax 版本 x.y.z 的最小 jaxlib 版本必须不大于 x.y.z

为了 jax 版本 x.y.zjaxlib 版本 lx.ly.lz 兼容,以下条件必须成立:

  • jaxlib 版本 (lx.ly.lz) 必须大于或等于最小 jaxlib 版本 (mx.my.mz)。

  • jax 版本 (x.y.z) 必须大于或等于 jaxlib 版本 (lx.ly.lz)。

这些约束意味着以下发布规则:

  • jax 可能会在任何时候独立发布,而不更新 jaxlib

  • 如果发布了新的 jaxlib,则必须同时发布 jax

这些 版本约束 目前由 jax 在导入时检查,而不是作为 Python 包版本约束来表达。jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们 为各种硬件和软件版本(例如,GPU、TPU 等)提供了单独的 jaxlib 轮子。由于我们不知道哪个选项对任何特定用户是正确的,我们不希望 pip 自动为我们安装 jaxlib 包。

在未来,我们希望将 jaxlib 中特定于硬件的部分分离到单独的插件中,届时可以将最低版本表示为 Python 包依赖项。目前,我们确实提供了特定平台的额外要求,这些要求安装了兼容的 jaxlib 版本,例如 jax[cuda]

如何安全地对 jaxlib 的 API 进行更改?#

  • jax 可能会随时放弃对旧版 jaxlib 的兼容性,只要将最低 jaxlib 版本提升至兼容版本即可。然而,请注意,即使是 jax 的未发布版本,其最低 jaxlib 版本也必须是一个已发布的版本!这使得我们可以在 CI 构建中使用已发布的 jaxlib 轮子,并允许 Python 开发者在 HEAD 上开发 jax 时,无需构建 jaxlib

    例如,要在 jax Python 代码中移除旧的向后兼容路径,只需提升最小 jaxlib 版本,然后删除兼容路径即可。

  • jaxlib 可能会放弃对低于其自身发布版本号的旧 jax 版本的兼容性。jax 施加的版本约束将禁止使用不兼容的 jaxlib

    例如,为了使 jaxlib 放弃由旧版 jax 使用的 Python 绑定 API,必须增加 jaxlib 的次版本号或主版本号。

  • 如果可能,对 jaxlib 的更改应以向后兼容的方式进行。

    一般来说,jaxlib 可以自由更改其 API,只要遵循 jax 与至少是最低版本的 jaxlib 兼容的规则。这意味着 jax 必须始终与至少两个版本的 jaxlib 兼容,即上一个发布版本和树尖版本(实际上是下一个发布版本)。如果保持兼容性,这样做会更容易,尽管可以通过 jax 的版本测试来实现不兼容的更改;见下文。

    例如,通常可以安全地向 jaxlib 添加一个新函数,但如果当前 jax 仍在使用该函数,则删除现有函数或更改其签名是不安全的。对 jax 的更改必须适用于所有大于最小版本的 jaxlib 版本,直至 HEAD。

请注意,这里的兼容性规则仅适用于 jaxjaxlib已发布 版本。它们不适用于未发布的版本;也就是说,如果在 jaxlib 中引入然后移除一个API是可以接受的,前提是该API从未发布,或者没有已发布的 jax 版本使用该API。

jaxlib 的源码是如何组织的?#

jaxlib 分布在两个主要仓库中,即主 JAX 仓库中的 jaxlib/ 子目录XLA 源代码树,位于 XLA 仓库内。XLA 中特定于 JAX 的部分主要在 xla/python 子目录 中。

C++ 的 JAX 部分,如 Python 绑定和运行时组件,位于 XLA 树内的原因部分是历史性的,部分是技术性的。

历史原因是,最初 xla/python 绑定被设想为通用目的的 Python 绑定,可能会与其他框架共享。实际上,这种情况越来越少见,xla/python 包含了许多 JAX 特定的部分,并且可能会包含更多。因此,最好简单地将 xla/python 视为 JAX 的一部分。

技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 一起原子更新。维护 Python API 的向后和向前兼容性比 C++ API 更容易,因此 xla/python 暴露了 Python API,并负责在 Python 级别维护向后兼容性。

jaxlib 是使用 Bazel 从 jax 仓库构建的。来自 XLA 仓库的 jaxlib 部分被整合到构建中 作为 Bazel 子模块。要更新构建过程中使用的 XLA 版本,必须更新 Bazel WORKSPACE 中的固定版本。这是根据需要手动完成的,但可以在每次构建的基础上覆盖。

我们如何在发布之间跨越 jaxjaxlib 的边界进行更改?#

jaxlib 版本是一个粗略的工具:它只让我们能够推理 发布

然而,由于 jaxjaxlib 代码分布在无法通过单次更改原子更新的不同仓库中,我们需要在比发布周期更细的粒度上管理兼容性。为了管理细粒度的兼容性,我们有额外的版本控制,这与 jaxlib 的发布版本号无关。

我们在 xla_client.py 文件中维护了一个额外的版本号(_version),该文件位于 XLA 仓库中。这个版本号与 JAX 的 C++ 部分一起定义在 xla/python 中,并且可以通过 jax._src.lib.xla_extension_version 在 JAX Python 中访问。每次对 XLA/Python 代码进行更改时,如果该更改对 jax 有向后兼容性影响,就必须增加这个版本号。然后,JAX Python 代码可以使用这个版本号来维护向后兼容性,例如:

from jax._src.lib import xla_extension_version

# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
  # Use new code path
  ...
else:
  # Use old code path.

请注意,此版本号是 附加 于已发布版本号的约束之上,也就是说,此版本号的存在是为了帮助在开发过程中管理未发布代码的兼容性。发布版本也必须遵循上述兼容性规则。