jax.debug.print

目录

jax.debug.print#

jax.debug.print(fmt, *args, ordered=False, **kwargs)[源代码]#

打印值并在分阶段输出的 JAX 函数中工作。

此函数 适用于 f-字符串,因为格式化被延迟了。因此,不要写 jax.debug.print(f"hello {bar}"),而是写 jax.debug.print("hello {bar}", bar=bar)

这个函数是一个围绕 jax.debug.callback() 的薄便利包装。其实现基本上是:

def debug_print(fmt: str, *args, **kwargs):
  jax.debug.callback(
      lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
      *args, **kwargs)

直接调用 jax.debug.callback() 可能会有用,而不是使用这个便利的包装器。例如,要在日志中获取调试打印,您可以结合使用 jax.debug.callback()logging.log

参数:
  • fmt (str) – 一个格式化字符串,例如 "hello {x}",将用于格式化输入参数,类似于 str.format。请参阅 Python 文档中的 字符串格式化格式化字符串语法

  • *args – 要格式化的位置参数列表,就像传递给 fmt.format 一样。

  • ordered (bool) – 一个仅关键字的参数,用于指示是否对与此 jax.debug.print 相关的其他有序 jax.debug.print 调用强制执行排序。

  • **kwargs – 要格式化的附加关键字参数,就像传递给 fmt.format 一样。

返回类型:

None