编译后的打印和断点#
The jax.debug
包提供了一些有用的工具,用于检查编译函数内部的值。
使用 jax.debug.print
和其他调试回调进行调试#
摘要: 使用 jax.debug.print()
在编译(例如 jax.jit
或 jax.pmap
装饰的)函数中将追踪的数组值打印到标准输出:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# 🤯 0.9092974662780762 🤯
通过一些变换,比如 jax.grad
和 jax.vmap
,你可以使用 Python 的内置 print
函数来打印数值。但 print
在 jax.jit
或 jax.pmap
中不起作用,因为这些变换会延迟数值评估。因此,请使用 jax.debug.print
代替!
从语义上讲,jax.debug.print
大致等同于以下 Python 函数
def debug.print(fmt: str, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
print(fmt.format(*args, **kwargs))
除了它可以被JAX分阶段处理和转换之外。更多详情请参阅 API参考
。
注意 fmt
不能是一个 f-字符串,因为 f-字符串会立即格式化,而对于 jax.debug.print
,我们希望延迟格式化直到稍后。
何时使用“debug”打印?#
你应该在 JAX 变换(如 jit
、vmap
等)中使用 jax.debug.print
来打印动态(即追踪的)数组值。对于打印静态值(如数组形状或数据类型),你可以使用普通的 Python print
语句。
为什么是“debug”打印?#
在调试的名义下,jax.debug.print
可以揭示关于计算如何评估的信息:
xs = jnp.arange(3.)
def f(x):
jax.debug.print("x: {}", x)
y = jnp.sin(x)
jax.debug.print("y: {}", y)
return y
jax.vmap(f)(xs)
# Prints: x: 0.0
# x: 1.0
# x: 2.0
# y: 0.0
# y: 0.841471
# y: 0.9092974
jax.lax.map(f, xs)
# Prints: x: 0.0
# y: 0.0
# x: 1.0
# y: 0.841471
# x: 2.0
# y: 0.9092974
请注意,打印结果的顺序不同!
通过揭示这些内部工作机制,jax.debug.print
的输出不遵循 JAX 通常的语义保证,例如 jax.vmap(f)(xs)
和 jax.lax.map(f, xs)
计算相同的东西(以不同的方式)。然而,这些评估顺序的细节正是我们在调试时可能想要看到的!
因此,在调试时使用 jax.debug.print
,而在语义保证重要时不使用。
更多 jax.debug.print
的示例#
除了使用 jit
和 vmap
的上述示例外,这里还有一些需要记住的示例。
在 jax.pmap
下打印#
当使用 jax.pmap
时,jax.debug.print
的输出可能会被重新排序!
xs = jnp.arange(2.)
def f(x):
jax.debug.print("x: {}", x)
return x
jax.pmap(f)(xs)
# Prints: x: 1.0
# x: 0.0
# OR
# Prints: x: 1.0
# x: 0.0
在 jax.grad
下的打印#
在 jax.grad
下,jax.debug.print
只会在前向传递时打印:
def f(x):
jax.debug.print("x: {}", x)
return x * 2.
jax.grad(f)(1.)
# Prints: x: 1.0
这种行为类似于 Python 内置的 print
在 jax.grad
下的工作方式。但通过在这里使用 jax.debug.print
,即使调用者应用了 jax.jit
,行为也是相同的。
要在反向传播时打印,只需使用 jax.custom_vjp
:
@jax.custom_vjp
def print_grad(x):
return x
def print_grad_fwd(x):
return x, None
def print_grad_bwd(_, x_grad):
jax.debug.print("x_grad: {}", x_grad)
return (x_grad,)
print_grad.defvjp(print_grad_fwd, print_grad_bwd)
def f(x):
x = print_grad(x)
return x * 2.
jax.grad(f)(1.)
# Prints: x_grad: 2.0
在其他转换中打印#
jax.debug.print
在其他转换中也能工作,例如 pjit
。
通过 jax.debug.callback
获得更多控制#
事实上,jax.debug.print
是 jax.debug.callback
的一个便捷封装,可以直接用于更精细的字符串格式控制,甚至是输出类型。
从语义上讲,jax.debug.callback
大致相当于以下 Python 函数
def callback(fun: Callable, *args: PyTree[Array], **kwargs: PyTree[Array]) -> None:
fun(*args, **kwargs)
return None
与 jax.debug.print
一样,这些回调应仅用于调试输出,如打印或绘图。打印和绘图基本上是无害的,但如果您将其用于其他任何用途,其行为在转换下可能会让您感到意外。例如,使用 jax.debug.callback
进行计时操作是不安全的,因为回调可能会被重新排序和异步执行(见下文)。
jax.debug.print
的优缺点#
优势#
打印调试简单直观
jax.debug.callback
可以用于其他无害的副作用
限制#
添加打印语句是一个手动过程
可能影响性能
使用 jax.debug.breakpoint()
进行交互式检查#
总结: 使用 jax.debug.breakpoint()
来暂停你的 JAX 程序的执行,以便检查值:
@jax.jit
def f(x):
y, z = jnp.sin(x), jnp.cos(x)
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution!
jax.debug.breakpoint()
实际上只是 jax.debug.callback(...)
的一个应用,它捕获了调用堆栈的信息。因此,它具有与 jax.debug.print
相同的转换行为(例如,vmap
-ing jax.debug.breakpoint()
会在映射的轴上展开它)。
用法#
在编译的 JAX 函数中调用 jax.debug.breakpoint()
将在命中断点时暂停您的程序。您将看到一个类似于 pdb
的提示符,允许您检查调用堆栈中的值。与 pdb
不同,您将无法逐步执行,但可以继续执行。
调试器命令:
help
- 打印出可用命令p
- 计算一个表达式并打印其结果pp
- 计算表达式并将其结果漂亮地打印出来u(p)
- 向上移动一个堆栈帧d(own)
- 向下移动一个堆栈帧w(here)/bt
- 打印出回溯信息l(ist)
- 打印代码上下文c(ont(inue))
- 恢复程序的执行q(uit)/exit
- 退出程序(在TPU上无效)
示例#
使用 jax.lax.cond
#
当与 jax.lax.cond
结合使用时,调试器可以成为检测 nan
或 inf
的有用工具。
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
pass
def false_fn(x):
jax.debug.breakpoint()
lax.cond(is_finite, true_fn, false_fn, x)
@jax.jit
def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 0.) # ==> Pauses during execution!
尖锐部分#
因为 jax.debug.breakpoint
只是 jax.debug.callback
的一个应用,它具有与 jax.debug.print
相同的 sharp bits as jax.debug.print
,还有一些额外的注意事项:
jax.debug.breakpoint
比jax.debug.print
实现 更多 的中间结果,因为它强制实现了调用栈中的所有值。jax.debug.breakpoint
比jax.debug.print
有更多的运行时开销,因为它可能需要将 JAX 程序中的所有中间值从设备复制到主机。
jax.debug.breakpoint()
的优缺点#
优势#
简单、直观且(某种程度上)标准
可以同时检查许多值,包括调用栈的上下文。
限制#
可能需要使用多个断点来精确定位错误的来源
具体化许多中间结果