jax.profiler.TraceAnnotation

jax.profiler.TraceAnnotation#

class jax.profiler.TraceAnnotation[源代码][源代码]#

在分析器中生成跟踪事件的上下文管理器。

跟踪事件的时间跨度涵盖了上下文所包围的代码的持续时间。

例如:

>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceAnnotation("my_label"):
...   result = jnp.dot(x, x.T).block_until_ready()

如果事件发生时进程正在被跟踪,这将导致“my_label”事件出现在跟踪时间线上。

__init__(self, arg0: str, /, **kwargs) None#

属性

is_enabled

set_metadata