安装#
使用 JAX 需要安装两个包:jax
,这是一个纯 Python 且跨平台的包,以及 jaxlib
,它包含编译后的二进制文件,并且需要针对不同的操作系统和加速器进行不同的构建。
总结: 对于大多数用户来说,一个典型的 JAX 安装可能看起来像这样:
仅CPU(Linux/macOS/Windows)
pip install -U jax
GPU (NVIDIA, CUDA 12)
pip install -U "jax[cuda12]"
TPU (Google Cloud TPU VM)
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
支持的平台#
下表显示了所有支持的平台和安装选项。检查您的设置是否受支持;如果显示为“是”或“实验性”,请点击相应的链接以了解如何更详细地安装JAX。
CPU#
pip 安装:CPU#
目前,JAX 团队为以下操作系统和架构发布了 jaxlib
轮子:
Linux, x86_64
Linux, aarch64
macOS, Intel
macOS, Apple ARM 架构
Windows, x86_64 (实验性)
要安装仅支持CPU版本的JAX,这在笔记本电脑上进行本地开发时可能很有用,你可以运行:
pip install --upgrade pip
pip install --upgrade jax
在 Windows 上,如果您的机器上尚未安装,您可能还需要安装 Microsoft Visual Studio 2019 Redistributable。
其他操作系统和架构需要从源代码构建。尝试在其他操作系统和架构上使用 pip 安装可能会导致 jaxlib
没有与 jax
一起安装,尽管 jax
可能成功安装(但在运行时失败)。
NVIDIA GPU#
JAX 支持具有 SM 版本 5.2(Maxwell)或更新的 NVIDIA GPU。请注意,由于 NVIDIA 已在其软件中放弃对 Kepler GPU 的支持,JAX 不再支持 Kepler 系列 GPU。
您必须首先安装 NVIDIA 驱动程序。建议您安装 NVIDIA 提供的最新驱动程序,但对于 Linux 上的 CUDA 12,驱动程序版本必须 >= 525.60.13。
如果你需要在一个较旧的驱动程序上使用较新的CUDA工具包,例如在一个你无法轻松更新NVIDIA驱动的集群上,你可能可以使用NVIDIA为此目的提供的CUDA向前兼容包。
pip 安装:NVIDIA GPU(CUDA,通过 pip 安装,更简单)#
有两种方法可以安装支持NVIDIA GPU的JAX:
使用从 pip 轮安装的 NVIDIA CUDA 和 cuDNN
使用自安装的 CUDA/cuDNN
JAX 团队强烈建议使用 pip 轮子安装 CUDA 和 cuDNN,因为这样要容易得多!
NVIDIA 仅发布了适用于 x86_64 和 aarch64 的 CUDA pip 包;在其他平台上,您必须使用本地的 CUDA 安装。
pip install --upgrade pip
# NVIDIA CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12]"
如果 JAX 检测到 NVIDIA CUDA 库的错误版本,您需要检查以下几项:
确保
LD_LIBRARY_PATH
未设置,因为LD_LIBRARY_PATH
可能会覆盖 NVIDIA CUDA 库。确保安装的 NVIDIA CUDA 库是 JAX 所要求的。重新运行上述安装命令应该可以解决问题。
pip 安装:NVIDIA GPU(CUDA,本地安装,较难)#
如果你倾向于使用预装的 NVIDIA CUDA 版本,你必须首先安装 NVIDIA CUDA 和 cuDNN。
JAX 为 Linux x86_64 和 Linux aarch64 提供了预构建的 CUDA 兼容轮子。其他操作系统和架构的组合也是可能的,但需要从源代码构建(参考 从源代码构建 了解更多)。
你应该使用至少与 NVIDIA CUDA 工具包对应的驱动版本 一样新的 NVIDIA 驱动版本。如果你需要使用较新的 CUDA 工具包与较旧的驱动程序,例如在集群上,你无法轻松更新 NVIDIA 驱动程序,你可能可以使用 NVIDIA 为此目的提供的 CUDA 向前兼容包。
JAX 目前提供一种 CUDA 轮子变体:
构建于 |
兼容于 |
---|---|
CUDA 12.3 |
CUDA >=12.1 |
CUDNN 9.1 |
CUDNN >=9.1, <10.0 |
NCCL 2.19 |
NCCL >=2.18 |
JAX 会检查你的库版本,如果它们不够新,将会报告错误。设置 JAX_SKIP_CUDA_CONSTRAINTS_CHECK
环境变量将禁用检查,但使用较旧版本的 CUDA 可能会导致错误或不正确的结果。
NCCL 是一个可选依赖项,仅在执行多GPU计算时需要。
要安装,请运行:
pip install --upgrade pip
# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.0 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_local]"
这些 pip
安装不适用于 Windows,并且可能会静默失败;请参阅上方的表格 above。
你可以使用以下命令找到你的CUDA版本:
nvcc --version
JAX 使用 LD_LIBRARY_PATH
来查找 CUDA 库,使用 PATH
来查找二进制文件(ptxas
, nvlink
)。请确保这些路径指向正确的 CUDA 安装。
JAX 需要 libdevice10.bc,这通常来自 cuda-nvvm 包。请确保它在您的 CUDA 安装中存在。
如果在使用预构建的轮子时遇到任何错误或问题,请在 GitHub 问题跟踪器 上通知 JAX 团队。
NVIDIA GPU Docker 容器#
NVIDIA 提供了 JAX Toolbox 容器,这些是最前沿的容器,包含 jax 的夜间发布版本以及一些模型/框架。
Google Cloud TPU#
pip 安装:Google Cloud TPU#
JAX 为 Google Cloud TPU 提供了预构建的轮子。要在云 TPU VM 中安装 JAX 以及适当版本的 jaxlib
和 libtpu
,您可以在您的云 TPU VM 中运行以下命令:
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
对于使用 Colab (https://colab.research.google.com/) 的用户,请确保您使用的是 TPU v2,而不是旧的、已弃用的 TPU 运行时。
Apple Silicon GPU (基于ARM架构)#
pip 安装:基于 Apple ARM 架构的硅基 GPU#
Apple 为基于 ARM 的 Apple GPU 硬件提供了一个实验性的 Metal 插件。详情请参阅 Apple 的 JAX on Metal 文档。
注意: Metal 插件有几个注意事项:
Metal 插件是新的且实验性的,并且存在一些 已知问题。请在 JAX 问题跟踪器上报告任何问题。
Metal 插件目前需要非常特定的
jax
和jaxlib
版本。随着插件 API 的成熟,这一限制将逐渐放宽。
AMD GPU#
JAX 有实验性的 ROCm 支持。安装 JAX 有两种方式:
使用 AMD 的 Docker 容器;或
从源代码构建(参考 从源代码构建 — 一个名为 为 AMD GPU 构建 ROCM
jaxlib
的附加说明 的部分)。
Conda (社区支持)#
Conda 安装#
有一个社区支持的 jax
的 Conda 构建。要使用 conda
安装它,只需运行:
conda install jax -c conda-forge
要在配备NVIDIA GPU的机器上安装它,请运行:
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
注意 conda-forge
分发的 cudatoolkit
缺少 JAX 所需的 ptxas
。因此,您必须从 nvidia
频道安装 cuda-nvcc
包,或者单独在您的机器上安装 CUDA,以便 ptxas
在您的路径中。上述频道的顺序很重要(conda-forge
在 nvidia
之前)。
如果你想覆盖JAX使用的CUDA版本,或者在没有GPU的机器上安装CUDA构建,请按照 conda-forge
网站的 技巧与窍门 部分的说明进行操作。
JAX 夜间安装#
每晚发布的版本反映了构建时 JAX 主仓库的状态,可能未通过完整的测试套件。
与安装 JAX 版本的指令不同,这里我们在命令行中明确列出了 JAX 的所有包,因此如果存在更新的版本,pip
将会升级它们。
仅限CPU:
pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
Google Cloud TPU:
pip install -U --pre jax[tpu] jaxlib libtpu-nightly -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
NVIDIA GPU (CUDA 12):
pip install -U --pre jax[cuda12] jaxlib jax-cuda12-plugin jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
NVIDIA GPU (CUDA 12) 传统版:
使用以下内容获取单体 CUDA jaxlibs 的历史夜间版本。您很可能不需要这个;不会再构建更多的单体 CUDA jaxlibs,并且现有的将在 2024 年 9 月之前过期。请使用上面的“CUDA 12”选项。
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
从源码构建 JAX#
参考 从源码构建。
安装旧版 jaxlib
轮子#
由于Python包索引上的存储限制,JAX团队会定期从http://pypi.org/project/jax的发布中移除旧的jaxlib
轮子。这些仍然可以通过这里的URL直接安装。例如:
# Install jaxlib on CPU via the wheel archive
pip install jax[cpu]==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
对于特定的旧版GPU轮子,请确保使用 jax_cuda_releases.html
URL;例如
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html