jax.export.Exported#

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)[源代码][源代码]#

一个降低到 StableHLO 的 JAX 函数。

参数:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

fun_name#

导出函数的名称,用于错误消息。

类型:

str

in_tree#

描述降低后的 JAX 函数元组 (args, kwargs) 的 PyTreeDef。实际的降低过程不依赖于 in_tree,但可以使用相同的参数结构调用导出的函数。

类型:

tree_util.PyTreeDef

in_avals#

输入抽象值的扁平元组。形状中可能包含维度表达式。

类型:

tuple[core.ShapedArray, …]

out_tree#

描述降低后的 JAX 函数结果的 PyTreeDef。

类型:

tree_util.PyTreeDef

out_avals#

输出抽象值的扁平元组。形状中可能包含维度表达式,维度变量可能包含在 in_avals 中。

类型:

tuple[core.ShapedArray, …]

in_shardings_hlo#

扁平化的输入分片,一个与 in_avals 长度相同的序列。None 表示未指定的分片。请注意,这些不包括网格或网格中使用的实际设备。有关如何将这些转换为可与 JAX API 一起使用的分片规范,请参见 in_shardings_jax

类型:

tuple[HloSharding | None, …]

out_shardings_hlo#

扁平化的输出分片,一个与 out_avals 长度相同的序列。None 表示未指定的分片。请注意,这些不包括网格或网格中使用的实际设备。请参阅 out_shardings_jax 以了解如何将这些转换为可与 JAX API 一起使用的分片规范。

类型:

tuple[HloSharding | None, …]

nr_devices#

该模块的设备数量已减少。

类型:

整数

platforms#

一个包含应导出函数的平台的元组。JAX 中的平台集合是开放的;用户可以添加平台。JAX 内置平台有:’tpu’, ‘cpu’, ‘cuda’, ‘rocm’。参见 https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export

类型:

tuple[str, …]

ordered_effects#

序列化模块中存在的顺序效应。这是从序列化版本9开始存在的。有关存在顺序效应时的调用约定,请参见 https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention

类型:

tuple[effects.Effect, …]

unordered_effects#

序列化模块中存在的无序效果。这是从序列化版本9开始存在的。

类型:

tuple[effects.Effect, …]

mlir_module_serialized#

序列化的低级 VHLO 模块。

类型:

字节

calling_convention_version#

导出模块的调用约定的版本号。更多版本详情请参见 https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions

类型:

整数

module_kept_var_idx#

in_avals 中必须传递给模块的参数的排序索引。其他参数已被丢弃,因为它们未被使用。

类型:

tuple[int, …]

uses_global_constants#

是否 mlir_module_serialized 使用了形状多态性或多平台导出。这可能是因为 in_avals 包含了维度变量,或者由于导出模块的内部调用具有维度变量或平台索引参数。此类模块在XLA编译之前需要进行形状细化。

类型:

布尔

disabled_safety_checks#

在导出时已禁用的安全检查描述符列表。参见 DisabledSafetyCheck 的文档字符串。

类型:

Sequence[DisabledSafetyCheck]

_get_vjp#

一个可选函数,它接受当前导出的函数并返回导出的 VJP 函数。VJP 函数接受一个扁平的参数列表,从原始参数开始,然后是每个原始输出的一个余切参数。它返回一个元组,其中包含与扁平原始输入相对应的余切。

类型:

Callable[[导出], 导出] | None

查看 [关于 mlir_module 调用约定的描述](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention)。

__init__(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)#
参数:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)

返回类型:

None

方法

__init__(fun_name, in_tree, in_avals, ...)

call(*args, **kwargs)

has_vjp()

返回此导出是否支持 VJP。

in_shardings_jax(mesh)

创建与 self.in_shardings_hlo 对应的 Shardings。

mlir_module()

out_shardings_jax(mesh)

创建与 self.out_shardings_hlo 对应的 Shardings。

serialize([vjp_order])

序列化一个导出的对象。

vjp()

获取导出的 VJP。

属性

in_shardings

lowering_platforms

已弃用。

mlir_module_serialization_version

已弃用。

out_shardings

uses_shape_polymorphism

已弃用。

fun_name

in_tree

in_avals

out_tree

out_avals

in_shardings_hlo

out_shardings_hlo

nr_devices

platforms

ordered_effects

unordered_effects

disabled_safety_checks

mlir_module_serialized

calling_convention_version

module_kept_var_idx

uses_global_constants