jax.distributed.initialize#
- jax.distributed.initialize(coordinator_address=None, num_processes=None, process_id=None, local_device_ids=None, cluster_detection_method=None, initialization_timeout=300, coordinator_bind_address=None)[源代码][源代码]#
初始化 JAX 分布式系统。
调用
initialize()
为在多主机GPU和云TPU上执行JAX做好准备。在进行任何JAX计算之前,必须调用initialize()
。JAX 分布式系统承担了多个角色:
它允许 JAX 进程相互发现并共享拓扑信息,
它执行健康检查,确保如果任何进程死亡,所有进程都会关闭,并且
它用于分布式检查点。
如果你使用的是 TPU、Slurm 或 Open MPI,所有参数都是可选的:如果省略,它们将自动选择。
cluster_detection_method
可用于选择特定的方法来检测这些分布式参数。你可以将任何自动的spec_detect_methods
传递给这个参数,尽管在 TPU、Slurm 或 Open MPI 的情况下这不是必需的。对于其他 MPI 安装,如果你安装了功能正常的mpi4py
,你可以传递cluster_detection_method="mpi4py"
来引导所需的参数。否则,您必须提供
coordinator_address
、num_processes
和process_id
参数给initialize()
。请注意:在某些系统上,特别是只能通过HTTP_PROXY、HTTPS_PROXY等代理变量访问外部网络的高性能计算集群,调用
initialize()
可能会超时。您可能需要在应用程序启动前取消设置这些变量。- 参数:
coordinator_address (str | None) – 进程 0 的 IP 地址和该进程应启动协调服务器的端口。端口的选择无关紧要,只要协调器和所有进程都同意该端口即可。仅在支持的环境中可能为
None
,在这种情况下将自动选择。请注意,像localhost
或127.0.0.1
这样的特殊地址通常意味着程序将绑定到本地接口,在多主机环境中运行时并不适用。num_processes (int | None) – 进程数量。仅在支持的环境中可以为
None
,在这种情况下,它将自动选择。process_id (int | None) – 当前进程的ID编号。集群中所有
process_id
的值必须是一个密集的范围0
,1
, …,num_processes - 1
。仅在支持的环境中可能为None
;如果为None
,它将自动选择。local_device_ids (int | Sequence[int] | None) – 将当前进程的可见设备限制为
local_device_ids
。如果为None
,则默认所有本地设备对进程可见,除非通过 Slurm 和 Open MPI 在 GPU 上启动进程。在这种情况下,它将默认为每个进程一个设备。cluster_detection_method (str | None) – 一个可选的字符串,用于尝试自动检测分布式运行的配置。请注意,“mpi4py”方法要求你在环境中安装一个可用的
mpi4py
,并使用如mpiexec
或mpirun
等兼容MPI的任务启动器启动应用程序。遗留的自动检测选项(OMPI、Slurm)仍然启用。initialization_timeout (int) – 连接重试的时间段(以秒为单位)。如果初始化时间超过指定的超时时间,初始化将出错。默认为300秒,即5分钟。
coordinator_bind_address (str | None) – 协调服务在进程 0 上应绑定的地址和端口。如果未指定,默认情况下会绑定到所有可用地址,端口与
coordinator_address
相同。在每个节点有多个网络接口的系统上,仅让协调服务监听一个地址/接口可能是不够的。
- 抛出:
RuntimeError – 如果
initialize()
被调用多次,或者在后台已经初始化后调用。
示例:
假设有两个GPU进程,进程0是指定的协调器,地址为
10.0.0.1:1234
。在执行任何其他操作之前,运行以下命令以初始化GPU集群。在进程 0 上:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=0)
在进程 1 上:
>>> jax.distributed.initialize(coordinator_address='10.0.0.1:1234', num_processes=2, process_id=1)