jax.experimental.pallas.debug_print#
- jax.experimental.pallas.debug_print(fmt, *args)[源代码][源代码]#
从 Pallas 内核内部打印标量值。
- 参数:
fmt (str) – 一个包含在输出中的格式字符串。格式字符串的限制取决于后端:* 在GPU上,当使用Triton时,
fmt
不能包含任何占位符({...}
),因为它总是在任何值之前打印。* 在GPU上,当使用实验性的Mosaic GPU后端时,fmt
必须包含每个要打印的值的占位符。不支持格式说明符和转换。* 在TPU上,如果``fmt`` 包含占位符,所有值必须是32位整数。如果没有占位符,值将在格式字符串之后打印。*args (jax.typing.ArrayLike) – 要打印的标量值。