JAX 调试标志#

JAX 提供了标志和上下文管理器,使得捕获错误更加容易。

jax_debug_nans 配置选项和上下文管理器#

摘要: 启用 jax_debug_nans 标志以自动检测在 jax.jit 编译代码中何时产生 NaN(但在 jax.pmapjax.pjit 编译代码中不会检测)。

jax_debug_nans 是一个 JAX 标志,当启用时,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理——当从 JIT 编译的函数中检测到 NaN 输出时,该函数会急切地重新运行(即不进行编译),并在产生 NaN 的特定原语处抛出错误。

用法#

如果你想追踪函数或梯度中NaN出现的位置,可以通过以下方式开启NaN检查器:

  • 设置环境变量 JAX_DEBUG_NANS=True

  • 在主文件顶部附近添加 jax.config.update("jax_debug_nans", True)

  • 在你的主文件中添加 jax.config.parse_flags_with_absl(),然后使用命令行标志设置选项,例如 --jax_debug_nans=True

示例#

import jax
jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!

jax_debug_nans 的优缺点#

优势#
  • 易于应用

  • 精确检测 NaN 产生的位置

  • 抛出一个标准的 Python 异常,并且与 PDB 事后调试兼容。

限制#
  • 不兼容 jax.pmapjax.pjit

  • 急切地重新运行函数可能会很慢

  • 误报错误(例如,故意创建的 NaNs)

jax_disable_jit 配置选项和上下文管理器#

总结: 启用 jax_disable_jit 标志以禁用即时编译,从而可以使用传统的 Python 调试工具,如 printpdb

jax_disable_jit 是一个 JAX 标志,启用时,会禁用 JAX 中的 JIT 编译(包括在控制流函数如 jax.lax.condjax.lax.scan 中)。

用法#

你可以通过以下方式禁用JIT编译:

  • 设置环境变量 JAX_DISABLE_JIT=True

  • 在主文件顶部附近添加 jax.config.update("jax_disable_jit", True)

  • 在你的主文件中添加 jax.config.parse_flags_with_absl(),然后使用命令行标志设置选项,例如 --jax_disable_jit=True

示例#

import jax
jax.config.update("jax_disable_jit", True)

def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint!

jax_disable_jit 的优缺点#

优势#
  • 易于应用

  • 启用使用 Python 内置的 breakpointprint

  • 抛出标准的 Python 异常,并且与 PDB 事后调试兼容。

限制#
  • 不兼容 jax.pmapjax.pjit

  • 在没有JIT编译的情况下运行函数可能会很慢