调试运行时值#
你是否遇到了梯度爆炸?是否因为NaN而咬牙切齿?只是想查看计算中的中间值?查看以下JAX调试工具!本页有总结,你可以点击底部的“阅读更多”链接以了解更多。
目录:
使用 jax.debug
进行交互式检查#
完整指南 在此
总结: 使用 jax.debug.print()
在 jax.jit
-、jax.pmap
- 和 pjit
-装饰的函数中打印值到标准输出,并使用 jax.debug.breakpoint()
暂停编译函数的执行以检查调用堆栈中的值:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("🤯 {x} 🤯", x=x)
y = jnp.sin(x)
jax.debug.breakpoint()
jax.debug.print("🤯 {y} 🤯", y=y)
return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯
阅读更多。
使用 jax.experimental.checkify
进行功能性错误检查#
完整指南 这里
总结: Checkify 允许你为 JAX 代码添加可 jit
的运行时错误检查(例如越界索引)。使用 checkify.checkify
变换与类似断言的 checkify.check
函数来为 JAX 代码添加运行时检查:
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
y = x[i]
z = jnp.sin(y)
return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
你也可以使用 checkify 来自动添加常见检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
阅读更多。
使用 JAX 的调试标志抛出 Python 错误#
完整指南 在此
总结: 启用 jax_debug_nans
标志以自动检测在 jax.jit
编译代码中何时产生 NaN(但在 jax.pmap
或 jax.pjit
编译代码中不适用),并启用 jax_disable_jit
标志以禁用 JIT 编译,从而可以使用传统的 Python 调试工具,如 print
和 pdb
。
import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
return x / y
jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
阅读更多。