jax.debug.print#
- jax.debug.print(fmt, *args, ordered=False, **kwargs)[源代码]#
打印值并在分阶段输出的 JAX 函数中工作。
此函数 不 适用于 f-字符串,因为格式化被延迟了。因此,不要写
jax.debug.print(f"hello {bar}")
,而是写jax.debug.print("hello {bar}", bar=bar)
。这个函数是一个围绕
jax.debug.callback()
的薄便利包装。其实现基本上是:def debug_print(fmt: str, *args, **kwargs): jax.debug.callback( lambda *args, **kwargs: print(fmt.format(*args, **kwargs)), *args, **kwargs)
直接调用
jax.debug.callback()
可能会有用,而不是使用这个便利的包装器。例如,要在日志中获取调试打印,您可以结合使用jax.debug.callback()
和logging.log
。