调试简介#

本节向您介绍一组内置的 JAX 调试方法 — jax.debug.print()jax.debug.breakpoint()jax.debug.callback() — 您可以将这些方法与各种 JAX 变换一起使用。

让我们从 jax.debug.print() 开始。

jax.debug.print 用于简单检查#

这里有一个经验法则:

回顾 即时编译,当使用 jax.jit() 转换函数时,Python 代码会在抽象跟踪器代替数组的情况下执行。因此,Python 的 print() 函数只会打印这个跟踪器值:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

Python 的 print 在跟踪时执行,此时运行时值还不存在。如果你想打印实际的运行时值,你可以使用 jax.debug.print()

@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314

同样地,在 jax.vmap() 中,使用 Python 的 print 只会打印追踪器;要打印被映射的值,请使用 jax.debug.print()

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314

这里是使用 jax.lax.map() 的结果,这是一个顺序映射而非向量化映射:

result = jax.lax.map(f, xs)
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314

注意顺序是不同的,因为 jax.vmap()jax.lax.map() 以不同的方式计算相同的结果。在调试时,评估顺序的细节正是你可能需要检查的内容。

下面是一个使用 jax.grad() 的示例,其中 jax.debug.print() 仅打印前向传递。在这种情况下,其行为类似于 Python 的 print(),但如果您在调用期间应用 jax.jit(),则其行为是一致的。

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0

有时,当参数不相互依赖时,使用 jax.debug.print() 调用可能会在使用 JAX 转换进行分阶段处理时以不同的顺序打印它们。如果你需要原始顺序,例如先 x: ... 然后 y: ...,请添加 ordered=True 参数。

例如:

@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)

要了解更多关于 jax.debug.print() 及其 Sharp Bits 的信息,请参阅 高级调试

jax.debug.breakpoint 用于 pdb 风格的调试#

总结: 使用 jax.debug.breakpoint() 暂停 JAX 程序的执行以检查值。

要在调试期间在编译的 JAX 程序的某些点暂停,您可以使用 jax.debug.breakpoint()。提示类似于 Python pdb,它允许您检查调用堆栈中的值。实际上,jax.debug.breakpoint()jax.debug.callback() 的应用,它捕获有关调用堆栈的信息。

要在 breakpoint 调试会话期间打印所有可用命令,请使用 help 命令。(完整的调试器命令、高级功能、其优点和局限性在 高级调试 中介绍。)

以下是一个调试器会话可能的样子:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution

JAX 调试器

对于依赖于值的断点,您可以使用运行时条件,如 jax.lax.cond()

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution

jax.debug.callback 用于在调试期间进行更多控制#

jax.debug.printjax.debug.breakpoint 都是通过更灵活的 jax.debug.callback 实现的,这提供了对通过 Python 回调执行的主机端逻辑的更大控制。它与 jax.jitjax.vmapjax.grad 和其他变换兼容(更多信息请参阅 外部回调 中的 回调的类型 表格)。

例如:

import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
WARNING:root:Logged value: 1.0

此回调与其他转换兼容,包括 jax.vmap()jax.grad()

x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0

这使得 jax.debug.callback() 在通用调试中非常有用。

你可以在 外部回调 中了解更多关于 jax.debug.callback() 和其他类型的 JAX 回调。

下一步#

查看 高级调试 以了解更多关于在 JAX 中调试的信息。