提前降低和编译#

JAX 提供了几种变换,如 jax.jitjax.pmap,返回一个在加速器或CPU上编译并运行的函数。正如JIT缩写所示,所有编译都是为执行_即时_进行的。

某些情况下需要 ahead-of-time (AOT) 编译。当你希望在执行前完全编译,或者你想控制编译过程中不同部分的执行时间时,JAX 提供了一些选项供你选择。

首先,让我们回顾一下编译的阶段。假设 f 是一个由 jax.jit() 输出的函数/可调用对象,例如 f = jax.jit(F),其中 F 是某个输入的可调用对象。当它被调用时,例如 f(x, y),其中 xy 是数组,JAX 按以下顺序执行:

  1. 阶段输出 原始 Python 可调用对象 F 的一个专门版本到内部表示。这种专门化反映了 F 对从参数 xy 的属性(通常是它们的形状和元素类型)推断出的输入类型的限制。

  2. 降低 这种专门的、分阶段计算到 XLA 编译器的输入语言,StableHLO。

  3. 编译 降低后的 HLO 程序以生成针对目标设备(CPU、GPU 或 TPU)的优化可执行文件。

  4. 执行 编译后的可执行文件,使用数组 xy 作为参数。

JAX 的 AOT API 让你直接控制步骤 #2、#3 和 #4(但 不是 #1),以及在此过程中的一些其他功能。例如:

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

请注意,降低的对象只能在它们被降低的同一进程中使用。对于导出用例,请参阅 导出 API。

有关降级和编译函数提供的功能的更多详细信息,请参阅 jax.stages 文档。

在上面的 jax.jit 替代方案中,你也可以 lower(...) jax.pmap() 的结果,以及 pjitxmap(分别来自 jax.experimental.pjitjax.experimental.maps)。在每种情况下,你都可以类似地 compile() 结果。

所有 jit 的可选参数——例如 static_argnums——在相应的降低、编译和执行过程中都会被尊重。同样的情况也适用于 pmappjitxmap

在上面的例子中,我们可以将 lower 的参数替换为任何具有 shapedtype 属性的对象:

>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
Array(10, dtype=int32)

更一般地,lower 只需要其参数在结构上提供 JAX 进行特化和降低所需的信息。对于上述典型的数组参数,这意味着 shapedtype 字段。对于静态参数,相比之下,JAX 需要实际的数组值(更多内容请参见 下文)。

使用与AOT编译函数不兼容的参数调用该函数会引发错误:

>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)  
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

与此相关,AOT编译的函数不能被JAX的即时变换jax.jitjax.grad()jax.vmap() 所变换。

使用静态参数降低#

使用静态参数降低强调了传递给 jax.jit 的选项、传递给 lower 的参数以及调用生成的编译函数所需的参数之间的交互。继续我们上面的例子:

>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<14> : tensor<i32>
    %0 = stablehlo.add %c, %arg0 : tensor<i32>
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

lower 的结果不能直接序列化以用于不同的进程。有关此目的的其他API,请参见 导出和序列化

注意,lower 这里像往常一样接受两个参数,但后续编译的函数只接受剩余的非静态第二个参数。静态的第一个参数(值为7)在降低时间被视为常量,并内置于降低的计算中,在此过程中它可能与其他常量折叠。在这种情况下,它乘以2的计算被简化,结果为常量14。

尽管上述 lower 的第二个参数可以用一个空形状/dtype结构替换,但静态的第一个参数必须是具体值。否则,降低操作将出错:

>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)  
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
Array(25, dtype=int32)

AOT 编译的函数不能被转换#

编译后的函数是针对特定的一组参数“类型”进行优化的,例如在我们运行的示例中,特定形状和元素类型的数组。从JAX的内部角度来看,诸如 jax.vmap() 之类的变换会以一种使编译后的类型签名无效的方式改变函数的类型签名。作为一种策略,JAX 简单地禁止编译后的函数参与这些变换。示例:

>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1.,  5.,  9.],
       [13., 17., 21.],
       [25., 29., 33.],
       [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)  
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

g_aot 参与自动微分(例如 jax.grad())时,会引发类似的错误。为了保持一致性,即使 jit 不会实质性地修改其参数的类型签名,jax.jit 的转换也是不允许的。

调试信息和分析,如果有的话#

除了主要的AOT功能(分离和显式的降低、编译和执行),JAX的各种AOT阶段还提供了一些额外的功能,以帮助调试和收集编译器反馈。

例如,如上面的初始示例所示,降低的函数通常提供文本表示。编译的函数也是如此,并且还提供来自编译器的成本和内存分析。所有这些都是通过 jax.stages.Loweredjax.stages.Compiled 对象的方法提供的(例如,上面的 lowered.as_text()compiled.cost_analysis())。

这些方法旨在作为手动检查和调试的辅助工具,而不是作为可靠的可编程API。它们的可用性和输出因编译器、平台和运行时而异。这带来了两个重要的注意事项:

  1. 如果某些功能在JAX的当前后端不可用,那么它的方法将返回一些无意义的结果(类似于False)。例如,如果JAX底层的编译器不提供成本分析,那么compiled.cost_analysis()将返回None

  2. 如果某些功能可用,对于相应方法提供的内容仍然有非常有限的保证。返回值不需要在类型、结构或值上保持一致——跨JAX配置、后端/平台、版本,甚至方法的调用。JAX不能保证compiled.cost_analysis()在一天的输出会在第二天保持不变。

如有疑问,请参阅 jax.stages 的包API文档。

检查暂存计算#

本笔记顶部的列表中的第1阶段提到了特化和分阶段处理,在降低之前。JAX内部对函数针对其参数类型的特化概念并不总是内存中的具体数据结构。要显式构造JAX内部函数特化的视图,请参见 jax.make_jaxpr()