jax.profiler.StepTraceAnnotation#
- class jax.profiler.StepTraceAnnotation(name, **kwargs)[源代码][源代码]#
在分析器中生成步骤跟踪事件的上下文管理器。
步骤跟踪事件跨越上下文所包围的代码的持续时间。分析器将为每个步骤跟踪事件提供性能分析。
例如,它可以用来标记训练步骤,并使分析器能够提供每个步骤的性能分析:
>>> while global_step < NUM_STEPS: ... with jax.profiler.StepTraceAnnotation("train", step_num=global_step): ... train_step() ... global_step += 1
如果在进程被 TensorBoard 追踪时事件发生,这将导致追踪时间线上显示一个“train xx”事件。此外,如果使用加速器,设备追踪时间线上也会显示一个“train xx”事件。请注意,可以将“step_num”设置为关键字参数,以将全局步数传递给分析器。
- 参数:
name (str)
方法
__init__
(self, arg0, /, **kwargs)属性
is_enabled
set_metadata