jax.named_call

目录

jax.named_call#

jax.named_call(fun, *, name=None)[源代码][源代码]#

在分阶段执行JAX计算时,为函数添加用户指定的名称。

当为即时编译到XLA(或其他后端如TensorFlow)准备计算时,JAX运行您的Python程序,但默认情况下不会保留与该程序相关的任何函数名称或其他元数据。这会使调试准备好的(和/或编译的)程序表示变得复杂,因为每个操作执行时可用的上下文信息有限。

named_call 告诉 JAX 将给定的函数作为具有特定名称的子计算进行分阶段处理。当分阶段处理的程序使用 XLA 编译时,这些命名的子计算会被保留,并在 TensorFlow Profiler 等调试工具中显示,例如 TensorBoard。当使用 experimental.jax2tf.convert() 将 JAX 程序分阶段处理到 TensorFlow 时,名称也会被保留。

参数:
  • fun (F) – 要包装的函数。这可以是任何可调用对象。

  • name (str | None) – 可选。用于命名在名称范围内创建的所有子计算的前缀。如果未指定,则使用 fun.__name__。

返回:

一个在 name_scope 中包装的 fun 版本。

返回类型:

F