编译后的打印和断点#

The jax.debug 包提供了一些有用的工具,用于检查编译函数内部的值。

使用 jax.debug.print 和其他调试回调进行调试#

摘要: 使用 jax.debug.print() 在编译(例如 jax.jitjax.pmap 装饰的)函数中将追踪的数组值打印到标准输出:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯

通过一些变换,比如 jax.gradjax.vmap,你可以使用 Python 的内置 print 函数来打印数值。但 printjax.jitjax.pmap 中不起作用,因为这些变换会延迟数值评估。因此,请使用 jax.debug.print 代替!

从语义上讲,jax.debug.print 大致等同于以下 Python 函数

def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  print(fmt.format(*args, **kwargs))

除了它可以被JAX分阶段处理和转换之外。更多详情请参阅 API参考

注意 fmt 不能是一个 f-字符串,因为 f-字符串会立即格式化,而对于 jax.debug.print,我们希望延迟格式化直到稍后。

何时使用“debug”打印?#

你应该在 JAX 变换(如 jitvmap 等)中使用 jax.debug.print 来打印动态(即追踪的)数组值。对于打印静态值(如数组形状或数据类型),你可以使用普通的 Python print 语句。

为什么是“debug”打印?#

在调试的名义下,jax.debug.print 可以揭示关于计算如何评估的信息:

xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
jax.vmap(f)(xs)
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

请注意,打印结果的顺序不同!

通过揭示这些内部工作机制,jax.debug.print 的输出不遵循 JAX 通常的语义保证,例如 jax.vmap(f)(xs)jax.lax.map(f, xs) 计算相同的东西(以不同的方式)。然而,这些评估顺序的细节正是我们在调试时可能想要看到的!

因此,在调试时使用 jax.debug.print,而在语义保证重要时不使用。

更多 jax.debug.print 的示例#

除了使用 jitvmap 的上述示例外,这里还有一些需要记住的示例。

jax.pmap 下打印#

当使用 jax.pmap 时,jax.debug.print 的输出可能会被重新排序!

xs = jnp.arange(2.)

def f(x):
  jax.debug.print("x: {}", x)
  return x
jax.pmap(f)(xs)
# Prints: x: 1.0
#         x: 0.0
# OR
# Prints: x: 1.0
#         x: 0.0

jax.grad 下的打印#

jax.grad 下,jax.debug.print 只会在前向传递时打印:

def f(x):
  jax.debug.print("x: {}", x)
  return x * 2.

jax.grad(f)(1.)
# Prints: x: 1.0

这种行为类似于 Python 内置的 printjax.grad 下的工作方式。但通过在这里使用 jax.debug.print,即使调用者应用了 jax.jit,行为也是相同的。

要在反向传播时打印,只需使用 jax.custom_vjp

@jax.custom_vjp
def print_grad(x):
  return x

def print_grad_fwd(x):
  return x, None

def print_grad_bwd(_, x_grad):
  jax.debug.print("x_grad: {}", x_grad)
  return (x_grad,)

print_grad.defvjp(print_grad_fwd, print_grad_bwd)


def f(x):
  x = print_grad(x)
  return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0

在其他转换中打印#

jax.debug.print 在其他转换中也能工作,例如 pjit

通过 jax.debug.callback 获得更多控制#

事实上,jax.debug.printjax.debug.callback 的一个便捷封装,可以直接用于更精细的字符串格式控制,甚至是输出类型。

从语义上讲,jax.debug.callback 大致相当于以下 Python 函数

def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
  fun(*args, **kwargs)
  return None

jax.debug.print 一样,这些回调应仅用于调试输出,如打印或绘图。打印和绘图基本上是无害的,但如果您将其用于其他任何用途,其行为在转换下可能会让您感到意外。例如,使用 jax.debug.callback 进行计时操作是不安全的,因为回调可能会被重新排序和异步执行(见下文)。

尖锐部分#

像大多数 JAX API 一样,jax.debug.print 如果你不小心使用,可能会让你受伤。

打印结果的排序#

当对 jax.debug.print 的不同调用涉及互不依赖的参数时,它们在分阶段输出时可能会被重新排序,例如通过 jax.jit 进行:

@jax.jit
def f(x, y):
  jax.debug.print("x: {}", x)
  jax.debug.print("y: {}", y)
  return x + y

f(2., 3.)
# Prints: x: 2.0
#         y: 3.0
# OR
# Prints: y: 3.0
#         x: 2.0

为什么?在底层,编译器获得了一个分阶段计算的功能表示,其中Python函数的命令顺序丢失,只保留了数据依赖性。对于功能纯正的代码,这种变化对用户是不可见的,但在存在打印等副作用的情况下,这种变化是显而易见的。

为了保留 jax.debug.print 在 Python 函数中编写的原始顺序,您可以使用 jax.debug.print(..., ordered=True),这将确保打印的相对顺序被保留。但是,在涉及并行性的 jax.pmap 和其他 JAX 转换中使用 ordered=True 会引发错误,因为在并行执行下无法保证顺序。

异步回调#

根据后端的不同,jax.debug.print 的执行可能是异步的,即不在主程序线程中进行。这意味着即使 JAX 函数已经返回了一个值,值仍可能被打印到屏幕上。

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
# <do something else>
# Prints: x: 2.

要在函数中阻塞 jax.debug.print,可以调用 jax.effects_barrier(),它将等待函数中所有剩余的副作用完成:

@jax.jit
def f(x):
  jax.debug.print("x: {}", x)
  return x
f(2.).block_until_ready()
jax.effects_barrier()
# Prints: x: 2.
# <do something else>

性能影响#

不必要的物化#

虽然 jax.debug.print 的设计旨在具有最小的性能开销,但它可能会干扰编译器优化,并可能影响您的 JAX 程序的内存占用。

def f(w, b, x):
  logits = w.dot(x) + b
  jax.debug.print("logits: {}", logits)
  return jax.nn.relu(logits)

在这个例子中,我们在一个线性层和激活函数之间打印中间值。像XLA这样的编译器可以执行融合优化,这可能会避免在内存中具体化 logits。但是当我们对 logits 使用 jax.debug.print 时,我们正在强制这些中间值被具体化,这可能会减慢程序速度并增加内存使用。

此外,当使用 jax.debug.printjax.pjit 时,会发生全局同步,这将使值在单个设备上具体化。

回调开销#

jax.debug.print 本质上会导致加速器与其主机之间的通信。底层机制因后端而异(例如 GPU 与 TPU),但在所有情况下,我们都需要将打印的值从设备复制到主机。在 CPU 情况下,这种开销较小。

此外,当使用 jax.debug.printjax.pjit 时,会发生全局同步,这会增加一些开销。

jax.debug.print 的优缺点#

优势#

  • 打印调试简单直观

  • jax.debug.callback 可以用于其他无害的副作用

限制#

  • 添加打印语句是一个手动过程

  • 可能影响性能

使用 jax.debug.breakpoint() 进行交互式检查#

总结: 使用 jax.debug.breakpoint() 来暂停你的 JAX 程序的执行,以便检查值:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution!

JAX 调试器

jax.debug.breakpoint() 实际上只是 jax.debug.callback(...) 的一个应用,它捕获了调用堆栈的信息。因此,它具有与 jax.debug.print 相同的转换行为(例如,vmap-ing jax.debug.breakpoint() 会在映射的轴上展开它)。

用法#

在编译的 JAX 函数中调用 jax.debug.breakpoint() 将在命中断点时暂停您的程序。您将看到一个类似于 pdb 的提示符,允许您检查调用堆栈中的值。与 pdb 不同,您将无法逐步执行,但可以继续执行。

调试器命令:

  • help - 打印出可用命令

  • p - 计算一个表达式并打印其结果

  • pp - 计算表达式并将其结果漂亮地打印出来

  • u(p) - 向上移动一个堆栈帧

  • d(own) - 向下移动一个堆栈帧

  • w(here)/bt - 打印出回溯信息

  • l(ist) - 打印代码上下文

  • c(ont(inue)) - 恢复程序的执行

  • q(uit)/exit - 退出程序(在TPU上无效)

示例#

使用 jax.lax.cond#

当与 jax.lax.cond 结合使用时,调试器可以成为检测 naninf 的有用工具。

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!

尖锐部分#

因为 jax.debug.breakpoint 只是 jax.debug.callback 的一个应用,它具有与 jax.debug.print 相同的 sharp bits as jax.debug.print,还有一些额外的注意事项:

  • jax.debug.breakpointjax.debug.print 实现 更多 的中间结果,因为它强制实现了调用栈中的所有值。

  • jax.debug.breakpointjax.debug.print 有更多的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。

jax.debug.breakpoint() 的优缺点#

优势#

  • 简单、直观且(某种程度上)标准

  • 可以同时检查许多值,包括调用栈的上下文。

限制#

  • 可能需要使用多个断点来精确定位错误的来源

  • 具体化许多中间结果