jax.named_call#
- jax.named_call(fun, *, name=None)[源代码][源代码]#
在分阶段执行JAX计算时,为函数添加用户指定的名称。
当为即时编译到XLA(或其他后端如TensorFlow)准备计算时,JAX运行您的Python程序,但默认情况下不会保留与该程序相关的任何函数名称或其他元数据。这会使调试准备好的(和/或编译的)程序表示变得复杂,因为每个操作执行时可用的上下文信息有限。
named_call 告诉 JAX 将给定的函数作为具有特定名称的子计算进行分阶段处理。当分阶段处理的程序使用 XLA 编译时,这些命名的子计算会被保留,并在 TensorFlow Profiler 等调试工具中显示,例如 TensorBoard。当使用
experimental.jax2tf.convert()
将 JAX 程序分阶段处理到 TensorFlow 时,名称也会被保留。- 参数:
fun (F) – 要包装的函数。这可以是任何可调用对象。
name (str | None) – 可选。用于命名在名称范围内创建的所有子计算的前缀。如果未指定,则使用 fun.__name__。
- 返回:
一个在 name_scope 中包装的 fun 版本。
- 返回类型:
F