jax.debug.breakpoint

目录

jax.debug.breakpoint#

jax.debug.breakpoint(*, backend=None, filter_frames=True, num_frames=None, ordered=False, token=None, **kwargs)[源代码][源代码]#

在程序的某一点进入断点。

参数:
  • backend (str | None) – 要使用的调试器后端。默认情况下,选择优先级最高的调试器,如果没有其他注册的调试器,则回退到CLI调试器。

  • filter_frames (bool) – 是否从回溯中过滤掉 JAX 内部的堆栈帧。由于一些库,如 Flax,也使用了 JAX 的堆栈帧过滤系统,此选项还可以影响是否过滤来自库的堆栈帧。

  • num_frames (int | None) – 在交互式调试器中可用于检查的当前堆栈帧之上的帧数。

  • ordered (bool) – 一个仅关键字的参数,用于指示是否对与此 jax.debug.breakpoint 相关的其他有序 jax.debug.breakpointjax.debug.print 调用强制执行排序。

  • token – 仅关键字参数;ordered 的替代方案。如果使用,则应传递一个 JAX 数组(或 JAX 数组的 pytree),并且在其值计算后将运行断点。这会原封不动地返回,并应传递回计算中。如果返回值在后续计算中未使用,则整个计算将被修剪,并且此断点将不会运行。

返回:

如果传递了 token,则其值将原样返回。否则,返回 None