jax.xla_computation

jax.xla_computation#

jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())[源代码][源代码]#

创建一个函数,该函数在给定示例参数的情况下生成其XLA计算。

警告

自 JAX v0.4.30 起,此功能已被弃用,并将在未来的 JAX 版本中移除。您可以用 提前降低和编译 API 替换它;例如,jax.xla_computation(fn)(*args) 可以用 jax.jit(fn).lower(*args).compiler_ir('hlo') 替换。更多示例请参见 JAX 0.4.30 变更日志

参数:
  • fun (Callable) – 用于形成XLA计算的函数。

  • static_argnums (int | Iterable[int]) – 参见 jax.jit() 的文档字符串。

  • axis_env (Sequence[tuple[AxisName, int]] | None) – 可选的,一系列的成对序列,其中第一个元素是轴名称,第二个元素是一个正整数,表示具有该名称的映射轴的大小。当降低涉及并行通信集体的函数时,此参数很有用,它指定了由 jax.pmap() 应用设置的轴名称/大小环境。请参见下面的示例。

  • in_parts – 可选的,如何对 fun 的每个参数进行分区或复制。这用于指定分区的 XLA 计算,更多信息请参见 sharded_jit

  • out_parts – 可选的,如何对 fun 的每个输出进行分区或复制。这用于指定分区的 XLA 计算,更多信息请参见 sharded_jit

  • backend (str | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • tuple_args (bool) – 可选的布尔值,默认为 False。如果为 True,生成的 XLA 计算将有一个单一的元组参数,该参数被解包为指定的函数参数。如果为 None,当参数数量超过 100 时,将启用元组化,因为某些平台对参数数量有限制。

  • instantiate_const_outputs (bool | None) – 已弃用的参数,没有任何作用。

  • return_shape (bool) – 可选的布尔值,默认为 False。如果为 True,被包装的函数返回一个对,其中第一个元素是 XLA 计算,第二个元素是一个与 fun 输出具有相同结构的 pytree,其中叶子是具有 shapedtype 属性的对象,表示相应输出叶子的类型。

  • donate_argnums (int | Iterable[int]) – 指定哪些参数被“捐赠”给计算。如果你在计算完成后不再需要这些参数,那么捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收你的一个输入缓冲区来存储结果。你不应该重用捐赠给计算的缓冲区,如果你尝试这样做,JAX 会抛出错误。

返回:

fun 的包装版本,当应用于示例参数时返回一个构建的 XLA 计算(参见 xla_client.py),从中可以使用 as_hlo_textas_serialized_hlo_module_protoas_hlo_dot_graph 等方法提取未优化的 XLA HLO 计算的表示。如果参数 return_shapeTrue,则包装函数返回一个对,其中第一个元素是 XLA 计算,第二个元素是一个 pytree,表示 fun 输出的结构、形状、dtypes 和命名形状。具体示例参数并不总是必要的。对于那些未由 static_argnums 指示的参数,任何具有 shapedtype 属性的对象都是可接受的(命名元组除外,它们被视为 Python 容器)。

返回类型:

Callable

例如:

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)  
>>> print(c.as_hlo_text())  
HloModule xla_computation_f.6

ENTRY xla_computation_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.1)
  sine.4 = f32[] sine(cosine.3)
  ROOT tuple.5 = (f32[]) tuple(sine.4)
}

或者,上面的 c 赋值可以写成:

>>> import types
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> c = jax.xla_computation(f)(scalar)  

以下是一个涉及并行集体和轴名称的示例:

>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)  
>>> print(c.as_hlo_text())  
HloModule jaxpr_computation.9
primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}

注意生成的 replica_groups。以下是一个生成更有趣的 replica_groups 的示例:

>>> from jax import lax
>>> def g(x):
...   rowsum = lax.psum(x, 'i')
...   colsum = lax.psum(x, 'j')
...   allsum = lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)  
>>> print(c.as_hlo_text())  
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}