jax.stages 模块#

编译执行过程各阶段的接口。

JAX 转换,如 jax.jitjax.pmap,它们在执行时即时编译,也支持一种常见的提前显式降级和编译方式。此模块定义了表示此过程各个阶段的类型。

更多信息,请参阅 AOT 演练

#

class jax.stages.Wrapped(*args, **kwargs)[源代码][源代码]#

一个准备被追踪、降低和编译的函数。

此协议反映了诸如 jax.jit 等函数的输出。调用它会导致即时(JIT)降低、编译和执行。它也可以在编译前显式降低,并在执行前编译结果。

__call__(*args, **kwargs)[源代码][源代码]#

执行包装的函数,根据需要降低和编译。

lower(*args, **kwargs)[源代码][源代码]#

为给定的参数显式地降低此函数。

一个降低的函数被移出Python并翻译成编译器的输入语言,可能以依赖后端的方式进行。它已准备好进行编译,但尚未编译。

返回:

一个表示降低的 Lowered 实例。

返回类型:

Lowered

trace(*args, **kwargs)[源代码][源代码]#

显式地追踪给定参数的此函数。

一个被跟踪的函数被从Python中分离出来并转换为一个jaxpr。它已经准备好进行降级,但尚未降级。

返回:

一个表示追踪的 Traced 实例。

返回类型:

Traced

class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)[源代码][源代码]#

将函数特化为特定参数类型和值的过程。

降低(lowering)是一种准备编译的计算。此类携带一个降低计算以及稍后编译和执行所需的其他信息。它还提供了一个通用API,用于查询JAX各种降低路径(jit()、:func:`~jax.pmap`等)中降低计算的属性。

参数:
  • lowering (XlaLowering)

  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

  • no_kwargs (bool)

as_text(dialect=None)[源代码][源代码]#

此降低的人类可读文本表示。

用于可视化和调试目的。这不需要是一个有效或可靠的序列化。它直接传递给外部调用者。

参数:

dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)

返回类型:

str

compile(compiler_options=None)[源代码][源代码]#

编译,返回一个相应的 Compiled 实例。

参数:

compiler_options (CompilerOptions | None)

返回类型:

Compiled

compiler_ir(dialect=None)[源代码][源代码]#

此降低的任意对象表示。

用于调试目的。这不是一个有效且可靠的序列化。输出在不同调用之间没有一致性保证。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

参数:

dialect (str | None) – 可选字符串,指定一个降低方言(例如“stablehlo”)

返回类型:

Any | None

cost_analysis()[源代码][源代码]#

执行成本估算的总结。

用于可视化和调试目的。此对象输出的数据结构简单,易于打印或序列化(例如,嵌套的字典、列表和元组,其叶子节点为数值)。然而,其结构可以是任意的:它可能在不同版本的 JAX 和 jaxlib 之间不一致,甚至在不同调用之间也不一致。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[源代码]#

对 (位置参数, 关键字参数) 的树结构。

class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)[源代码][源代码]#

针对特定类型/值编译的函数的表示形式。

编译后的计算与一个可执行文件以及执行它所需的其他信息相关联。它还提供了一个通用API,用于查询JAX各种编译路径和后端中编译计算的属性。

参数:
  • args_info (Any)

  • out_tree (tree_util.PyTreeDef)

__call__(*args, **kwargs)[源代码][源代码]#

作为函数调用自身。

as_text()[源代码][源代码]#

此可执行文件的人类可读文本表示。

用于可视化和调试目的。这不是一个有效且可靠的序列化。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

返回类型:

str | None

cost_analysis()[源代码][源代码]#

执行成本估算的总结。

用于可视化和调试目的。此对象输出的数据结构简单,易于打印或序列化(例如,嵌套的字典、列表和元组,其叶子节点为数值)。然而,其结构可以是任意的:它可能在不同版本的 JAX 和 jaxlib 之间不一致,甚至在不同调用之间也不一致。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

返回类型:

Any | None

property in_tree: tree_util.PyTreeDef[源代码]#

对 (位置参数, 关键字参数) 的树结构。

memory_analysis()[源代码][源代码]#

估计内存需求的总结。

用于可视化和调试目的。此对象输出的数据结构简单,易于打印或序列化(例如,嵌套的字典、列表和元组,其叶子节点为数值)。然而,其结构可以是任意的:它可能在不同版本的 JAX 和 jaxlib 之间不一致,甚至在不同调用之间也不一致。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

返回类型:

Any | None

runtime_executable()[源代码][源代码]#

此可执行文件的任意对象表示。

用于调试目的。这不是有效的或可靠的序列化。输出的结果在不同调用之间没有一致性保证。

如果不可用则返回 None ,例如基于后端、编译器或运行时。

返回类型:

Any | None