即时编译#

在本节中,我们将进一步探讨JAX的工作原理,以及如何使其表现出色。我们将讨论 jax.jit() 变换,它将对JAX Python函数进行即时(JIT)编译,以便在XLA中高效执行。

JAX 变换的工作原理#

在上一节中,我们讨论了 JAX 允许我们转换 Python 函数。JAX 通过将每个函数简化为一系列 基本 操作来实现这一点,每个操作代表一个基本的计算单元。

查看函数背后基本操作序列的一种方法是使用 jax.make_jaxpr()

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

文档的 理解 Jaxprs 部分提供了关于上述输出含义的更多信息。

重要的是,注意到 jaxpr 并没有捕捉到函数中存在的副作用:其中没有任何内容对应于 global_list.append(x)。这是一个特性,而不是一个错误:JAX 变换被设计为理解无副作用(即函数式纯)的代码。如果 纯函数副作用 是不熟悉的术语,这在《🔪 JAX - The Sharp Bits 🔪: Pure Functions》中有更详细的解释。

不纯函数是危险的,因为在 JAX 变换下它们可能不会按预期行为运行;它们可能会静默失败,或者产生令人惊讶的下游错误,如泄露的 Tracer。此外,JAX 通常无法检测到何时存在副作用。(如果你想进行调试打印,请使用 jax.debug.print()。要以性能为代价表达一般的副作用,请参见 jax.experimental.io_callback()。要以性能为代价检查 Tracer 泄露,请使用 with jax.check_tracer_leaks())。

在追踪时,JAX 通过一个 追踪器 对象包装每个参数。这些追踪器在函数调用期间(这在常规 Python 中发生)记录所有在它们上执行的 JAX 操作。然后,JAX 使用这些追踪记录来重建整个函数。该重建的输出就是 jaxpr。由于追踪器不记录 Python 的副作用,它们不会出现在 jaxpr 中。然而,副作用仍然在追踪过程中发生。

注意:Python 的 print() 函数不是纯函数:文本输出是该函数的一个副作用。因此,任何 print() 调用只会在跟踪期间发生,并且不会出现在 jaxpr 中:

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))
printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }

看到打印的 x 是一个 Traced 对象了吗?这就是 JAX 内部的工作机制。

Python 代码至少运行一次这一事实严格来说是实现细节,因此不应依赖于此。然而,理解这一点是有用的,因为当你在调试时可以利用它来打印出计算的中间值。

一个关键的理解是,jaxpr 捕获了函数在给定参数上的执行情况。例如,如果我们有一个 Python 条件语句,jaxpr 只会知道我们采取的分支:

def log2_if_rank_2(x):
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
{ lambda ; a:i32[3]. let  in (a,) }

即时编译一个函数#

如前所述,JAX 允许操作在 CPU/GPU/TPU 上使用相同的代码执行。让我们来看一个计算 缩放指数线性单元SELU)的例子,这是深度学习中常用的一种操作:

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
1.35 ms ± 155 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

上面的代码一次向加速器发送一个操作。这限制了XLA编译器优化我们函数的能力。

自然地,我们想要做的是给 XLA 编译器尽可能多的代码,以便它可以完全优化。为此,JAX 提供了 jax.jit() 变换,它将 JIT 编译一个 JAX 兼容的函数。下面的示例展示了如何使用 JIT 来加速之前的函数。

selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()
421 μs ± 40.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

以下是刚刚发生的事情:

  1. 我们将 selu_jit 定义为 selu 的编译版本。

  2. 我们在 x 上调用了一次 selu_jit。这是 JAX 进行追踪的地方——毕竟,它需要一些输入来包装在追踪器中。然后,jaxpr 使用 XLA 编译成针对你的 GPU 或 TPU 优化的非常高效的代码。最后,执行编译后的代码以满足调用。后续对 selu_jit 的调用将直接使用编译后的代码,完全跳过 Python 实现。(如果我们没有单独包含预热调用,一切仍然会工作,但编译时间将包含在基准测试中。它仍然会更快,因为我们在基准测试中运行了许多循环,但这将不是一个公平的比较。)

  3. 我们测试了编译版本的执行速度。(注意使用了 block_until_ready(),这是由于 JAX 的 异步调度 所必需的)。

为什么我们不能只是JIT一切?#

在看完上面的例子后,你可能会想知道我们是否应该简单地对每个函数应用 jax.jit()。要理解为什么不是这样,以及我们何时应该/不应该应用 jit,让我们先看看 JIT 不起作用的一些情况。

# Condition on value of x.

def f(x):
  if x > 0:
    return x
  else:
    return 2 * x

jax.jit(f)(10)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22572/2956679937.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
# While loop conditioned on x and n.

def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

jax.jit(g)(10, 20)  # Raises an error
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22572/722961019.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

在这两种情况下,问题在于我们试图使用运行时值来调节程序的跟踪时间流。JIT 中的跟踪值,如这里的 xn,只能通过它们的静态属性(如 shapedtype)来影响控制流,而不能通过它们的值。有关 Python 控制流与 JAX 之间交互的更多细节,请参阅 🔪 JAX - The Sharp Bits 🔪: Control Flow

处理这个问题的一种方法是重写代码以避免对值进行条件判断。另一种方法是使用特殊的 宽松控制流jax.lax.cond()。然而,有时这是不可能或不切实际的。在这种情况下,你可以考虑只对函数的一部分进行即时编译(JIT)。例如,如果函数中最耗费计算的部分在循环内部,我们可以只对那部分内部进行即时编译(但请确保检查下一节关于缓存的内容,以避免自找麻烦):

# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

g_inner_jitted(10, 20)
Array(30, dtype=int32, weak_type=True)

将参数标记为静态#

如果我们确实需要即时编译一个在输入值上有条件的函数,我们可以通过指定 static_argnumsstatic_argnames 来告诉 JAX 为特定输入使用一个不那么抽象的跟踪器。这样做的好处是生成的 jaxpr 和编译后的产物依赖于传递的特定值,因此 JAX 将不得不为每个新的指定静态输入值重新编译函数。只有在函数保证只会看到有限的一组静态值时,这才会是一个好的策略。

f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
10
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
30

在使用 jit 作为装饰器时指定此类参数,一个常见的模式是使用 Python 的 functools.partial()

from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

print(g_jit_decorated(10, 20))
30

JIT 和缓存#

在第一次JIT调用的编译开销中,理解 jax.jit() 如何以及何时缓存之前的编译是有效使用它的关键。

假设我们定义 f = jax.jit(g)。当我们第一次调用 f 时,它将被编译,生成的 XLA 代码将被缓存。后续调用 f 时将重用缓存的代码。这就是 jax.jit 如何弥补编译的前期成本。

如果我们指定 static_argnums,那么缓存的代码将仅用于标记为静态的参数值相同的情况。如果其中任何一个发生变化,就会发生重新编译。如果有许多值,那么您的程序可能会花费比逐个执行操作更多的时间进行编译。

避免在循环或其他Python作用域内定义的临时函数上调用 jax.jit()。在大多数情况下,JAX将能够在后续调用 jax.jit() 时使用已编译并缓存的函数。然而,由于缓存依赖于函数的哈希值,当等效函数被重新定义时,这会成为一个问题。这将在每次循环中导致不必要的编译:

from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
jit called in a loop with partials:
141 ms ± 6.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with lambdas:
163 ms ± 21 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
jit called in a loop with caching:
652 μs ± 46.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)