GPU 内存分配#
JAX 在运行第一个 JAX 操作时将预分配 75% 的总 GPU 内存。 预分配可以减少分配开销和内存碎片,但有时会导致内存不足 (OOM) 错误。如果你的 JAX 进程因 OOM 失败,可以使用以下环境变量来覆盖默认行为:
XLA_PYTHON_CLIENT_PREALLOCATE=false
这将禁用预分配行为。JAX 将改为根据需要分配 GPU 内存,可能会减少总体内存使用量。然而,这种行为更容易导致 GPU 内存碎片化,这意味着使用大部分可用 GPU 内存的 JAX 程序在禁用预分配时可能会出现内存不足(OOM)的情况。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
如果启用了预分配,这将使 JAX 预分配 XX% 的总 GPU 内存,而不是默认的 75%。减少预分配的内存量可以修复 JAX 程序启动时发生的 OOM 问题。
XLA_PYTHON_CLIENT_ALLOCATOR=platform
这使得JAX按需分配所需资源,并释放不再需要的内存(注意,这是唯一会释放GPU内存而不是重用它的配置)。这非常慢,因此不推荐常规使用,但对于以最小可能的GPU内存占用运行或调试OOM故障可能很有用。
OOM 失败的常见原因#
- 同时运行多个 JAX 进程。
可以使用
XLA_PYTHON_CLIENT_MEM_FRACTION
为每个进程分配适当的内存,或者设置XLA_PYTHON_CLIENT_PREALLOCATE=false
。- 同时运行 JAX 和 GPU TensorFlow。
TensorFlow 也默认预分配内存,因此这与同时运行多个 JAX 进程类似。
一种解决方案是使用仅CPU的TensorFlow(例如,如果你只使用TF进行数据加载)。你可以通过命令
tf.config.experimental.set_visible_devices([], "GPU")
阻止TensorFlow使用GPU。或者,使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。还有一些类似的选项用于配置 TensorFlow 的 GPU 内存分配(在 TF1 中为 gpu_memory_fraction 和 allow_growth,这些应在传递给tf.Session
的tf.ConfigProto
中设置。参见 TF2 的 使用 GPU:限制 GPU 内存增长)。- 在显示GPU上运行JAX。
使用
XLA_PYTHON_CLIENT_MEM_FRACTION
或XLA_PYTHON_CLIENT_PREALLOCATE
。