分析设备内存#

备注

2023年5月更新:我们推荐使用 Tensorboard 分析 进行设备内存分析。在完成分析后,打开 Tensorboard 分析器的 memory_viewer 标签以获取更详细且易于理解的设备内存使用情况。

JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可以用于:

  • 找出在给定时间哪些数组和可执行文件在GPU内存中,或者

  • 追踪内存泄漏。

安装#

JAX 设备内存分析器发出的输出可以使用 pprof(google/pprof)进行解释。首先按照其 安装说明 安装 pprof。在撰写本文时,安装 pprof 需要先安装 1.16 版本以上的 GoGraphviz,然后运行

go install github.com/google/pprof@latest

这将安装 pprof$GOPATH/bin/pprof,其中 GOPATH 默认为 ~/go

备注

来自 google/pprofpprof 版本与作为 gperftools 包一部分分发的同名旧工具不同。gperftools 版本的 pprof 将无法与 JAX 一起使用。

理解 JAX 程序如何使用 GPU 或 TPU 内存#

设备内存分析器的一个常见用途是找出为什么 JAX 程序使用了大量的 GPU 或 TPU 内存,例如在尝试调试内存不足问题时。

要捕获设备内存配置文件到磁盘,请使用 jax.profiler.save_device_memory_profile()。例如,考虑以下 Python 程序:

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

如果我们先运行上面的程序,然后再执行

pprof --web memory.prof

pprof 打开一个包含以下调用图格式的设备内存概要可视化的网页浏览器:

设备内存分析示例

调用图是对Python堆栈在分配每个活动缓冲区时的可视化。例如,在这种情况下,可视化显示func2及其被调用者负责分配了76.30MB,其中38.15MB是在从func1调用func2时分配的。有关如何解释调用图可视化的更多信息,请参阅pprof文档

使用 jax.jit() 编译的函数对设备内存分析器是不透明的。也就是说,在 jit 编译的函数内部分配的任何内存都将归因于整个函数。

在示例中,调用 block_until_ready() 是为了确保在收集设备内存配置文件之前 func2 完成。更多详情请参见 异步分发

调试内存泄漏#

我们还可以使用 JAX 设备内存分析器,通过使用 pprof 来可视化在不同时间点获取的两个设备内存配置文件之间的内存使用变化,从而追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。

import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc()

如果我们简单地在执行结束时可视化设备内存配置文件(memory9.prof),可能不会很明显地看出 anotherfunc 循环的每次迭代都会累积更多的设备内存分配:

pprof --web memory9.prof

执行结束时的设备内存概况

afunction 内部的大量但固定的分配主导了性能分析,但不会随时间增长。

通过使用 pprof--diff_base 功能 来可视化内存使用在循环迭代中的变化,我们可以识别出程序的内存使用为何会随时间增加:

pprof --web --diff_base memory1.prof memory9.prof

执行结束时的设备内存概况

可视化显示,内存增长可以归因于 anotherfunc 内部对 normal 的调用。