torch.mps.event 的源代码
import torch
[docs]class Event:
r"""MPS事件的包装器。
MPS事件是同步标记,可用于监控设备的进度,精确测量时间,并同步MPS流。
参数:
enable_timing (bool, 可选): 指示事件是否应测量时间
(默认值: ``False``)
"""
def __init__(self, enable_timing=False):
self.__eventId = torch._C._mps_acquireEvent(enable_timing)
def __del__(self):
# 检查torch._C是否已经被销毁
if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
torch._C._mps_releaseEvent(self.__eventId)
[docs] def record(self):
r"""在默认流中记录事件。"""
torch._C._mps_recordEvent(self.__eventId)
[docs] def wait(self):
r"""使提交到默认流的所有未来工作等待此事件。"""
torch._C._mps_waitForEvent(self.__eventId)
[docs] def query(self):
r"""如果事件当前捕获的所有工作都已完成,则返回True。"""
return torch._C._mps_queryEvent(self.__eventId)
[docs] def synchronize(self):
r"""等待此事件中当前捕获的所有工作完成。
这会阻止CPU线程继续执行,直到事件完成。
"""
torch._C._mps_synchronizeEvent(self.__eventId)
[docs] def elapsed_time(self, end_event):
r"""返回事件记录后到end_event记录前的毫秒数。
"""
return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)