多主机和多进程环境#
介绍#
本指南解释了如何在GPU集群和Cloud TPU pod等环境中使用JAX,在这些环境中,加速器分布在多个CPU主机或JAX进程上。我们将这些称为“多进程”环境。
本指南特别关注如何在多进程环境中使用集体通信操作(例如 jax.lax.psum()
),尽管根据您的使用情况,其他通信方法也可能有用(例如 RPC、mpi4jax)。如果您还不熟悉 JAX 的集体操作,我们建议从 并行编程简介 部分开始。JAX 中多进程环境的一个重要要求是加速器之间的直接通信链接,例如 Cloud TPU 的高速互连或 GPU 的 NCCL。这些链接允许集体操作在多个进程的加速器上以高性能运行。
多进程编程模型#
关键概念:
每个主机至少需要运行一个 JAX 进程。
你应该使用
jax.distributed.initialize()
初始化集群。每个进程都有一组独特的 本地 设备可以访问。全局 设备是所有进程中所有设备的集合。
使用标准的 JAX 并行 API,如
jit()
(参见 并行编程简介 教程)和shard_map()
。jax.jit 只接受全局形状的数组。shard_map 允许你切换到每个设备的形状。确保所有进程以相同的顺序运行相同的并行计算。
确保所有进程具有相同数量的本地设备。
确保所有设备相同(例如,所有 V100,或所有 H100)。
启动 JAX 进程#
与其他分布式系统不同,在其他系统中,单个控制节点管理多个工作节点,JAX 使用“多控制器”编程模型,其中每个 JAX Python 进程独立运行,有时称为 单程序,多数据 (SPMD) 模型。通常,相同的 JAX Python 程序在每个进程中运行,每个进程的执行之间只有细微的差异(例如,不同的进程将加载不同的输入数据)。此外,你必须手动在每个主机上运行你的 JAX 程序! JAX 不会自动从单个程序调用中启动多个进程。
(多个进程的需求是为什么本指南不作为笔记本提供——我们目前没有好的方法从单个笔记本管理多个Python进程。)
初始化集群#
要初始化集群,您应在每个进程的开始处调用 jax.distributed.initialize()
。jax.distributed.initialize()
必须在程序早期调用,在执行任何 JAX 计算之前。
API jax.distributed.initialize()
接受几个参数,即:
coordinator_address
: 集群中进程0的IP地址,以及该进程上可用的端口。进程0将通过该IP地址和端口启动一个JAX服务,集群中的其他进程将连接到该服务。coordinator_bind_address
: 集群中进程0上的JAX服务将绑定的IP地址和端口。默认情况下,它将绑定到所有可用接口,使用与coordinator_address
相同的端口。num_processes
: 集群中的进程数量process_id
: 此进程的ID号,范围为[0 .. num_processes)
。local_device_ids
: 将当前进程的可见设备限制为local_device_ids
。
例如在GPU上,典型用法是:
import jax
jax.distributed.initialize(coordinator_address="192.168.0.1:1234",
num_processes=2,
process_id=0)
在云TPU、Slurm和Open MPI环境中,您可以简单地调用 jax.distributed.initialize()
而不带任何参数。参数的默认值将自动选择。当在GPU上运行时,假设每个GPU启动一个进程,即每个进程将只分配一个可见的本地设备。否则,假设每个主机启动一个进程,即每个进程将分配所有本地设备。仅当通过 mpirun
/mpiexec
启动JAX进程时,才会使用Open MPI自动初始化。
import jax
jax.distributed.initialize()
在TPU上,目前调用 jax.distributed.initialize()
是可选的,但推荐这样做,因为它启用了额外的检查点和健康检查功能。
本地设备 vs. 全局设备#
在我们从程序中运行多进程计算之前,理解 本地 和 全局 设备之间的区别是很重要的。
一个进程的 本地 设备是那些它可以直接寻址并在其上启动计算的设备。 例如,在GPU集群中,每个主机只能在其直接连接的GPU上启动计算。在Cloud TPU pod中,每个主机只能在其直接连接的8个TPU核心上启动计算(更多详情请参阅Cloud TPU系统架构文档)。你可以通过 jax.local_devices()
查看一个进程的本地设备。
全局设备是跨所有进程的设备。 一个计算可以跨越进程间的设备,并通过设备之间的直接通信链接执行集体操作,只要每个进程在其本地设备上启动计算。你可以通过 jax.devices()
查看所有可用的全局设备。一个进程的本地设备总是全局设备的一个子集。
运行多进程计算#
那么,你如何实际运行涉及跨进程通信的计算呢?使用与单进程中相同的并行评估API!
例如,shard_map()
可以用于在多个进程中运行并行计算。(如果你还不熟悉如何使用 shard_map
在单个进程内的多个设备上运行,请查看 并行编程简介 教程。)从概念上讲,这可以被认为是运行一个在主机间分片的单个数组上的 pmap,其中每个主机“只看到”其本地输入和输出的分片。
以下是一个多进程 pmap 的实际应用示例:
# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
>>> import jax
>>> jax.distributed.initialize() # On GPU, see above for the necessary arguments.
>>> jax.device_count() # total number of accelerator devices in the cluster
32
>>> jax.local_device_count() # number of accelerator devices attached to this host
8
# The psum is performed over all mapped devices across the pod slice
>>> xs = jax.numpy.ones(jax.local_device_count())
>>> jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)
确保所有进程以相同的顺序运行相同的跨进程计算非常重要。 通常,在每个进程中运行相同的 JAX Python 程序就足够了。需要注意的一些常见陷阱可能会导致尽管运行相同的程序,但计算顺序不同:
将不同形状的输入传递给同一个并行函数可能会导致挂起或返回不正确的值。只要不同形状的输入在进程之间产生相同形状的每个设备数据分片,它们就是安全的;例如,传递不同的前导批次大小以在每个进程的不同数量的本地设备上运行是可以的,但让每个进程将其批次填充到不同的最大示例长度则不行。
“最后一组”问题,即在(训练)循环中调用并行函数,并且一个或多个进程比其他进程更早退出循环。这将导致其余进程挂起,等待已经完成的进程开始计算。
基于集合非确定性排序的条件可能导致代码进程挂起。例如,在当前Python版本中迭代
set
或在 Python 3.7之前 迭代dict
可能会导致在不同进程中出现不同的排序,即使插入顺序相同。