jax.experimental.host_callback.barrier_wait

jax.experimental.host_callback.barrier_wait#

jax.experimental.host_callback.barrier_wait(logging_name=None)[源代码]#

阻塞调用线程,直到所有当前的外流数据处理完毕。

等待所有设备上已经运行的计算的所有回调被接收并由Python回调处理。如果在处理回调时发生异常,则引发CallbackException。

这是通过将一个特殊的抽头计算排入我们正在监听输出反馈的所有设备来实现的。一旦所有这些抽头计算完成,我们就从 barrier_wait 返回。

注意:如果任何设备正忙且无法接受新的计算任务,这将导致死锁。

参数:

logging_name (str | None) – 一个可选的字符串,将在本次调用的日志语句中使用。请参阅模块文档中的 调试

更多详情请参阅 jax.experimental.host_callback 模块文档。