jax.xla_computation#
- jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())[源代码][源代码]#
创建一个函数,该函数在给定示例参数的情况下生成其XLA计算。
警告
自 JAX v0.4.30 起,此功能已被弃用,并将在未来的 JAX 版本中移除。您可以用 提前降低和编译 API 替换它;例如,
jax.xla_computation(fn)(*args)
可以用jax.jit(fn).lower(*args).compiler_ir('hlo')
替换。更多示例请参见 JAX 0.4.30 变更日志。- 参数:
fun (Callable) – 用于形成XLA计算的函数。
axis_env (Sequence[tuple[AxisName, int]] | None) – 可选的,一系列的成对序列,其中第一个元素是轴名称,第二个元素是一个正整数,表示具有该名称的映射轴的大小。当降低涉及并行通信集体的函数时,此参数很有用,它指定了由
jax.pmap()
应用设置的轴名称/大小环境。请参见下面的示例。in_parts – 可选的,如何对
fun
的每个参数进行分区或复制。这用于指定分区的 XLA 计算,更多信息请参见sharded_jit
。out_parts – 可选的,如何对
fun
的每个输出进行分区或复制。这用于指定分区的 XLA 计算,更多信息请参见sharded_jit
。backend (str | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串:
'cpu'
、'gpu'
或'tpu'
。tuple_args (bool) – 可选的布尔值,默认为
False
。如果为True
,生成的 XLA 计算将有一个单一的元组参数,该参数被解包为指定的函数参数。如果为 None,当参数数量超过 100 时,将启用元组化,因为某些平台对参数数量有限制。instantiate_const_outputs (bool | None) – 已弃用的参数,没有任何作用。
return_shape (bool) – 可选的布尔值,默认为
False
。如果为True
,被包装的函数返回一个对,其中第一个元素是 XLA 计算,第二个元素是一个与fun
输出具有相同结构的 pytree,其中叶子是具有shape
和dtype
属性的对象,表示相应输出叶子的类型。donate_argnums (int | Iterable[int]) – 指定哪些参数被“捐赠”给计算。如果你在计算完成后不再需要这些参数,那么捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收你的一个输入缓冲区来存储结果。你不应该重用捐赠给计算的缓冲区,如果你尝试这样做,JAX 会抛出错误。
- 返回:
fun
的包装版本,当应用于示例参数时返回一个构建的 XLA 计算(参见 xla_client.py),从中可以使用as_hlo_text
、as_serialized_hlo_module_proto
和as_hlo_dot_graph
等方法提取未优化的 XLA HLO 计算的表示。如果参数return_shape
为True
,则包装函数返回一个对,其中第一个元素是 XLA 计算,第二个元素是一个 pytree,表示fun
输出的结构、形状、dtypes 和命名形状。具体示例参数并不总是必要的。对于那些未由static_argnums
指示的参数,任何具有shape
和dtype
属性的对象都是可接受的(命名元组除外,它们被视为 Python 容器)。- 返回类型:
Callable
例如:
>>> import jax >>> >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) >>> c = jax.xla_computation(f)(3.) >>> print(c.as_hlo_text()) HloModule xla_computation_f.6 ENTRY xla_computation_f.6 { constant.2 = pred[] constant(false) parameter.1 = f32[] parameter(0) cosine.3 = f32[] cosine(parameter.1) sine.4 = f32[] sine(cosine.3) ROOT tuple.5 = (f32[]) tuple(sine.4) }
或者,上面的
c
赋值可以写成:>>> import types >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> c = jax.xla_computation(f)(scalar)
以下是一个涉及并行集体和轴名称的示例:
>>> def f(x): return x - jax.lax.psum(x, 'i') >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) >>> print(c.as_hlo_text()) HloModule jaxpr_computation.9 primitive_computation.3 { parameter.4 = s32[] parameter(0) parameter.5 = s32[] parameter(1) ROOT add.6 = s32[] add(parameter.4, parameter.5) } ENTRY jaxpr_computation.9 { tuple.1 = () tuple() parameter.2 = s32[] parameter(0) all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3 ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7) }
注意生成的
replica_groups
。以下是一个生成更有趣的replica_groups
的示例:>>> from jax import lax >>> def g(x): ... rowsum = lax.psum(x, 'i') ... colsum = lax.psum(x, 'j') ... allsum = lax.psum(x, ('i', 'j')) ... return rowsum, colsum, allsum ... >>> axis_env = [('i', 4), ('j', 2)] >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) >>> print(c.as_hlo_text()) HloModule jaxpr_computation__1.19 [removed uninteresting text here] ENTRY jaxpr_computation__1.19 { tuple.1 = () tuple() parameter.2 = f32[] parameter(0) all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3 all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8 all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13 ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17) }