分析设备内存#
备注
2023年5月更新:我们推荐使用 Tensorboard 分析 进行设备内存分析。在完成分析后,打开 Tensorboard 分析器的 memory_viewer
标签以获取更详细且易于理解的设备内存使用情况。
JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可以用于:
找出在给定时间哪些数组和可执行文件在GPU内存中,或者
追踪内存泄漏。
安装#
JAX 设备内存分析器发出的输出可以使用 pprof(google/pprof)进行解释。首先按照其 安装说明 安装 pprof
。在撰写本文时,安装 pprof
需要先安装 1.16 版本以上的 Go,Graphviz,然后运行
go install github.com/google/pprof@latest
这将安装 pprof
到 $GOPATH/bin/pprof
,其中 GOPATH
默认为 ~/go
。
备注
来自 google/pprof 的 pprof
版本与作为 gperftools
包一部分分发的同名旧工具不同。gperftools
版本的 pprof
将无法与 JAX 一起使用。
理解 JAX 程序如何使用 GPU 或 TPU 内存#
设备内存分析器的一个常见用途是找出为什么 JAX 程序使用了大量的 GPU 或 TPU 内存,例如在尝试调试内存不足问题时。
要捕获设备内存配置文件到磁盘,请使用 jax.profiler.save_device_memory_profile()
。例如,考虑以下 Python 程序:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
如果我们先运行上面的程序,然后再执行
pprof --web memory.prof
pprof
打开一个包含以下调用图格式的设备内存概要可视化的网页浏览器:
调用图是对Python堆栈在分配每个活动缓冲区时的可视化。例如,在这种情况下,可视化显示func2
及其被调用者负责分配了76.30MB,其中38.15MB是在从func1
调用func2
时分配的。有关如何解释调用图可视化的更多信息,请参阅pprof文档。
使用 jax.jit()
编译的函数对设备内存分析器是不透明的。也就是说,在 jit
编译的函数内部分配的任何内存都将归因于整个函数。
在示例中,调用 block_until_ready()
是为了确保在收集设备内存配置文件之前 func2
完成。更多详情请参见 异步分发。
调试内存泄漏#
我们还可以使用 JAX 设备内存分析器,通过使用 pprof
来可视化在不同时间点获取的两个设备内存配置文件之间的内存使用变化,从而追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
如果我们简单地在执行结束时可视化设备内存配置文件(memory9.prof
),可能不会很明显地看出 anotherfunc
循环的每次迭代都会累积更多的设备内存分配:
pprof --web memory9.prof
在 afunction
内部的大量但固定的分配主导了性能分析,但不会随时间增长。
通过使用 pprof
的 --diff_base
功能 来可视化内存使用在循环迭代中的变化,我们可以识别出程序的内存使用为何会随时间增加:
pprof --web --diff_base memory1.prof memory9.prof
可视化显示,内存增长可以归因于 anotherfunc
内部对 normal
的调用。