Jax 和 Jaxlib 版本控制#
为什么 jax 和 jaxlib 是分开的包?#
我们以两个独立的Python轮子发布JAX,即 jax,这是一个纯Python轮子,以及 jaxlib,这是一个主要由C++编写的轮子,包含以下库:
XLA,
XLA 使用的 LLVM 组件,
MLIR 基础设施,例如 StableHLO Python 绑定。
用于快速JIT和PyTree操作的JAX特定C++库。
我们分发独立的 jax 和 jaxlib 包,因为这使得在不构建 C++ 代码或甚至不安装 C++ 工具链的情况下,也能轻松处理 JAX 的 Python 部分。jaxlib 是一个大型库,对许多用户来说不容易构建,但大多数对 JAX 的更改只涉及 Python 代码。通过允许 Python 部分独立于 C++ 部分更新,我们提高了 Python 更改的开发速度。
此外,jaxlib 的构建成本不低,但我们希望能够在 CPU 资源不多的环境中迭代和运行 JAX 测试,例如在 Github Actions 或笔记本电脑上。我们的许多 CI 构建只是使用预构建的 jaxlib,而不是需要在每个 PR 上重新构建 JAX 的 C++ 部分。
正如我们将看到的,单独分发 jax 和 jaxlib 是有代价的,因为这要求对 jaxlib 的更改保持向后兼容的 API。然而,我们认为总体上更容易进行 Python 更改是可取的,即使这使得 C++ 更改稍微困难一些。
jax 和 jaxlib 是如何版本化的?#
摘要:在 JAX 源码树中,jax 和 jaxlib 共享相同的版本号,但作为独立的 Python 包发布。安装时,jax 包的版本必须大于或等于 jaxlib 的版本,并且 jaxlib 的版本必须大于或等于 jax 指定的最小 jaxlib 版本。
jax 和 jaxlib 的发布版本号均为 x.y.z,其中 x 是主版本号,y 是次版本号,z 是可选的补丁发布版本号。版本号必须遵循 PEP 440。版本号的比较是基于整数元组的字典序比较。
每个 jax 版本都有一个关联的最小 jaxlib 版本 mx.my.mz。jax 版本 x.y.z 的最小 jaxlib 版本必须不大于 x.y.z。
为了 jax 版本 x.y.z 和 jaxlib 版本 lx.ly.lz 兼容,以下条件必须成立:
jaxlib 版本 (
lx.ly.lz) 必须大于或等于最小 jaxlib 版本 (mx.my.mz)。jax 版本 (
x.y.z) 必须大于或等于 jaxlib 版本 (lx.ly.lz)。
这些约束意味着以下发布规则:
jax可能会在任何时候独立发布,而不更新jaxlib。如果发布了新的
jaxlib,则必须同时发布jax。
这些 版本约束 目前由 jax 在导入时检查,而不是作为 Python 包版本约束来表达。jax 在运行时检查 jaxlib 版本,而不是使用 pip 包版本约束,因为我们 为各种硬件和软件版本(例如,GPU、TPU 等)提供了单独的 jaxlib 轮子。由于我们不知道哪个选项对任何特定用户是正确的,我们不希望 pip 自动为我们安装 jaxlib 包。
在未来,我们希望将 jaxlib 中特定于硬件的部分分离到单独的插件中,届时可以将最低版本表示为 Python 包依赖项。目前,我们确实提供了特定平台的额外要求,这些要求安装了兼容的 jaxlib 版本,例如 jax[cuda]。
如何安全地对 jaxlib 的 API 进行更改?#
jax可能会随时放弃对旧版jaxlib的兼容性,只要将最低jaxlib版本提升至兼容版本即可。然而,请注意,即使是jax的未发布版本,其最低jaxlib版本也必须是一个已发布的版本!这使得我们可以在 CI 构建中使用已发布的jaxlib轮子,并允许 Python 开发者在 HEAD 上开发jax时,无需构建jaxlib。例如,要在
jaxPython 代码中移除旧的向后兼容路径,只需提升最小 jaxlib 版本,然后删除兼容路径即可。jaxlib可能会放弃对低于其自身发布版本号的旧jax版本的兼容性。jax施加的版本约束将禁止使用不兼容的jaxlib。例如,为了使
jaxlib放弃由旧版jax使用的 Python 绑定 API,必须增加jaxlib的次版本号或主版本号。如果可能,对
jaxlib的更改应以向后兼容的方式进行。一般来说,
jaxlib可以自由更改其 API,只要遵循jax与至少是最低版本的jaxlib兼容的规则。这意味着jax必须始终与至少两个版本的jaxlib兼容,即上一个发布版本和树尖版本(实际上是下一个发布版本)。如果保持兼容性,这样做会更容易,尽管可以通过jax的版本测试来实现不兼容的更改;见下文。例如,通常可以安全地向
jaxlib添加一个新函数,但如果当前jax仍在使用该函数,则删除现有函数或更改其签名是不安全的。对jax的更改必须适用于所有大于最小版本的jaxlib版本,直至 HEAD。
请注意,这里的兼容性规则仅适用于 jax 和 jaxlib 的 已发布 版本。它们不适用于未发布的版本;也就是说,如果在 jaxlib 中引入然后移除一个API是可以接受的,前提是该API从未发布,或者没有已发布的 jax 版本使用该API。
jaxlib 的源码是如何组织的?#
jaxlib 分布在两个主要仓库中,即主 JAX 仓库中的 jaxlib/ 子目录 和 XLA 源代码树,位于 XLA 仓库内。XLA 中特定于 JAX 的部分主要在 xla/python 子目录 中。
C++ 的 JAX 部分,如 Python 绑定和运行时组件,位于 XLA 树内的原因部分是历史性的,部分是技术性的。
历史原因是,最初 xla/python 绑定被设想为通用目的的 Python 绑定,可能会与其他框架共享。实际上,这种情况越来越少见,xla/python 包含了许多 JAX 特定的部分,并且可能会包含更多。因此,最好简单地将 xla/python 视为 JAX 的一部分。
技术原因是 XLA C++ API 不稳定。通过将 XLA:Python 绑定保留在 XLA 树中,它们的 C++ 实现可以与 XLA 的 C++ API 一起原子更新。维护 Python API 的向后和向前兼容性比 C++ API 更容易,因此 xla/python 暴露了 Python API,并负责在 Python 级别维护向后兼容性。
jaxlib 是使用 Bazel 从 jax 仓库构建的。来自 XLA 仓库的 jaxlib 部分被整合到构建中 作为 Bazel 子模块。要更新构建过程中使用的 XLA 版本,必须更新 Bazel WORKSPACE 中的固定版本。这是根据需要手动完成的,但可以在每次构建的基础上覆盖。
我们如何在发布之间跨越 jax 和 jaxlib 的边界进行更改?#
jaxlib 版本是一个粗略的工具:它只让我们能够推理 发布。
然而,由于 jax 和 jaxlib 代码分布在无法通过单次更改原子更新的不同仓库中,我们需要在比发布周期更细的粒度上管理兼容性。为了管理细粒度的兼容性,我们有额外的版本控制,这与 jaxlib 的发布版本号无关。
我们在 xla_client.py 文件中维护了一个额外的版本号(_version),该文件位于 XLA 仓库中。这个版本号与 JAX 的 C++ 部分一起定义在 xla/python 中,并且可以通过 jax._src.lib.xla_extension_version 在 JAX Python 中访问。每次对 XLA/Python 代码进行更改时,如果该更改对 jax 有向后兼容性影响,就必须增加这个版本号。然后,JAX Python 代码可以使用这个版本号来维护向后兼容性,例如:
from jax._src.lib import xla_extension_version
# 123 is the new version number for _version in xla_client.py
if xla_extension_version >= 123:
# Use new code path
...
else:
# Use old code path.
请注意,此版本号是 附加 于已发布版本号的约束之上,也就是说,此版本号的存在是为了帮助在开发过程中管理未发布代码的兼容性。发布版本也必须遵循上述兼容性规则。