jax.experimental.pallas.debug_print

目录

jax.experimental.pallas.debug_print#

jax.experimental.pallas.debug_print(fmt, *args)[源代码][源代码]#

从 Pallas 内核内部打印标量值。

参数:
  • fmt (str) – 一个包含在输出中的格式字符串。格式字符串的限制取决于后端:* 在GPU上,当使用Triton时,fmt 不能包含任何占位符({...}),因为它总是在任何值之前打印。* 在GPU上,当使用实验性的Mosaic GPU后端时,fmt 必须包含每个要打印的值的占位符。不支持格式说明符和转换。* 在TPU上,如果``fmt`` 包含占位符,所有值必须是32位整数。如果没有占位符,值将在格式字符串之后打印。

  • *args (jax.typing.ArrayLike) – 要打印的标量值。