jax.debug.callback

目录

jax.debug.callback#

jax.debug.callback(callback, *args, ordered=False, **kwargs)[源代码]#

调用一个可分阶段的 Python 回调。

更多解释,请参见 外部回调

jax.debug.callback 允许你传入一个可以在分阶段 JAX 程序内部调用的 Python 函数。jax.debug.callback 遵循现有的 JAX 转换 操作语义,因此对副作用是不可知的。这意味着在存在高阶原语和转换的情况下,副作用可能会被丢弃、复制或潜在地重新排序。

我们希望这种行为,因为我们希望 jax.debug.callback 是“无害的”,即我们希望这些原语在尽可能少地改变 JAX 计算的同时,尽可能多地揭示关于它们的信息,例如计算的哪些部分被复制或丢弃。

参数:
  • callback (Callable[..., None]) – 一个返回 None 的 Python 可调用对象。

  • *args (Any) – 回调函数的位置参数。

  • ordered (bool) – 一个仅限关键字的参数,用于指示暂存计算是否会强制执行此回调相对于其他有序回调的顺序。

  • **kwargs (Any) – 回调的关键字参数。

返回:

返回类型:

None

参见