jax.profiler.annotate_函数#
- jax.profiler.annotate_function(func, name=None, **decorator_kwargs)[源代码][源代码]#
为函数执行生成跟踪事件的装饰器。
例如:
>>> @jax.profiler.annotate_function ... def f(x): ... return jnp.dot(x, x.T).block_until_ready() >>> >>> result = f(jnp.ones((1000, 1000)))
如果函数执行发生在进程被 TensorBoard 追踪期间,这将导致在追踪时间线上显示一个“f”事件。
可以通过
functools.partial()
将参数传递给装饰器。>>> from functools import partial
>>> @partial(jax.profiler.annotate_function, name="event_name") ... def f(x): ... return jnp.dot(x, x.T).block_until_ready()
>>> result = f(jnp.ones((1000, 1000)))
- 参数:
func (Callable)
name (str | None)