GPU 内存分配

GPU 内存分配#

JAX 在运行第一个 JAX 操作时将预分配 75% 的总 GPU 内存。 预分配可以减少分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。如果你的 JAX 进程因 OOM 失败,可以使用以下环境变量来覆盖默认行为:

XLA_PYTHON_CLIENT_PREALLOCATE=false

这将禁用预分配行为。JAX 将改为根据需要分配 GPU 内存,可能会减少总体内存使用量。然而,这种行为更容易导致 GPU 内存碎片化,这意味着使用大部分可用 GPU 内存的 JAX 程序在禁用预分配时可能会出现内存不足(OOM)的情况。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果启用了预分配,这将使 JAX 预分配 XX% 的总 GPU 内存,而不是默认的 75%。减少预分配的内存量可以修复 JAX 程序启动时发生的 OOM 问题。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

这使得JAX按需分配所需资源,并释放不再需要的内存(注意,这是唯一会释放GPU内存而不是重用它的配置)。这非常慢,因此不推荐常规使用,但对于以最小可能的GPU内存占用运行或调试OOM故障可能很有用。

OOM 失败的常见原因#

同时运行多个 JAX 进程。

可以使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程分配适当的内存,或者设置 XLA_PYTHON_CLIENT_PREALLOCATE=false

同时运行 JAX 和 GPU TensorFlow。

TensorFlow 也默认预分配内存,因此这与同时运行多个 JAX 进程类似。

一种解决方案是使用仅CPU的TensorFlow(例如,如果你只使用TF进行数据加载)。你可以通过命令 tf.config.experimental.set_visible_devices([], "GPU") 阻止TensorFlow使用GPU。

或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。还有一些类似的选项用于配置 TensorFlow 的 GPU 内存分配(在 TF1 中为 gpu_memory_fractionallow_growth,这些应在传递给 tf.Sessiontf.ConfigProto 中设置。参见 TF2 的 使用 GPU:限制 GPU 内存增长)。

在显示GPU上运行JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE