如何在 JAX 中思考#

在 Colab 中打开 在 Kaggle 中打开

JAX 提供了一个简单而强大的 API 来编写加速的数值代码,但在 JAX 中有效地工作有时需要额外的考量。本文档旨在帮助您从基础理解 JAX 的工作原理,以便您可以更有效地使用它。

JAX 与 NumPy#

关键概念:

  • JAX 提供了一个受 NumPy 启发的接口以便于使用。

  • 通过鸭子类型,JAX 数组通常可以作为 NumPy 数组的直接替代品使用。

  • 与 NumPy 数组不同,JAX 数组始终是不可变的。

NumPy 提供了一个广为人知且强大的 API 用于处理数值数据。为了方便,JAX 提供了 jax.numpy,该接口与 NumPy API 密切相似,并提供了便捷的 JAX 入门。几乎可以用 numpy 完成的任何操作都可以用 jax.numpy 完成:

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
../_images/ed74117ce798d02f04559155709be03bef63cfa850e6af47b918884ed471961f.png
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
../_images/7566bdaf92d2f9beda43e4d3ddee916a69d996ec12b671de077d49428fd54fd2.png

代码块除了将 np 替换为 jnp 外是相同的,结果也是相同的。正如我们所看到的,JAX 数组通常可以直接替代 NumPy 数组用于绘图等操作。

数组本身作为不同的 Python 类型实现:

type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl

Python的鸭子类型允许在许多地方交替使用JAX数组和NumPy数组。

然而,JAX和NumPy数组之间有一个重要的区别:JAX数组是不可变的,这意味着一旦创建,其内容无法更改。

以下是一个在NumPy中变更数组的示例:

# NumPy:可变数组
x = np.arange(10)
x[0] = 10
print(x)
[10  1  2  3  4  5  6  7  8  9]

在JAX中,等效操作会导致错误,因为JAX数组是不可变的:

%xmode minimal
Exception reporting mode: Minimal
# JAX:不可变数组
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

为了更新单个元素,JAX 提供了一个 索引更新语法,它返回一个更新后的副本:

y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9]

NumPy, lax 和 XLA:JAX API 层次#

关键概念:

  • jax.numpy 是一个高级封装,提供了一个熟悉的接口。

  • jax.lax 是一个较低级的 API,更严格且通常更强大。

  • 所有 JAX 操作都是基于 XLA(加速线性代数编译器)中的操作实现的。

如果你查看 jax.numpy 的源代码,你会发现所有的操作最终都是通过在 jax.lax 中定义的函数来实现的。你可以将 jax.lax 看作是一个更严格但通常更强大的多维数组操作 API。

例如,虽然 jax.numpy 会隐式提升参数以允许混合数据类型之间的操作,但 jax.lax 不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API 隐式地支持混合类型。
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0)  # jax.lax API 需要显式类型提升。
ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results
		at loc("jit(add)/jit(main)/add"(callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22855/3496455845.py":2:0) at callsite("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at callsite("_pseudo_sync_runner"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8) at callsite("InteractiveShell._run_cell"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3130:21) at callsite("InteractiveShell.run_cell"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3075:21) at callsite("ZMQInteractiveShell.run_cell"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/ipykernel/zmqshell.py":549:15) at callsite("IPythonKernel.do_execute"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/ipykernel/ipkernel.py":449:26) at "Kernel.execute_request"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/ipykernel/kernelbase.py":778:28))))))))))))
Define JAX_DUMP_IR_TO to dump the module.

如果直接使用 jax.lax,在这种情况下,您必须显式地进行类型提升:

lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)

与此同时,jax.lax 还提供了比 NumPy 支持的更通用的操作的高效 API。

例如,考虑一维卷积,可以在 NumPy 中这样表示:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

在内部,这个NumPy操作被转换为一个更通用的卷积,由lax.conv_general_dilated实现:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # 注意:明确推广
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

这是一个批量卷积操作,旨在提高在深度神经网络中常用的卷积类型的效率。它需要更多的样板代码,但比NumPy提供的卷积更灵活和可扩展(有关JAX卷积的更多详细信息,请参见 JAX中的卷积)。

从本质上讲,所有 jax.lax 操作都是对XLA中操作的Python封装;例如,这里的卷积实现由 XLA:ConvWithGeneralPadding 提供。 每个JAX操作最终都是用这些基本的XLA操作来表示的,这使得即时编译(JIT)得以实现。

JIT还是不JIT#

关键概念:

  • 默认情况下,JAX一次执行一个操作,按顺序进行。

  • 通过使用即时编译(JIT)装饰器,可以将一序列操作一起优化并一次性运行。

  • 并不是所有JAX代码都可以进行JIT编译,因为它要求数组的形状在编译时是静态且已知的。

所有JAX操作都是通过XLA来表达,这使得JAX能够利用XLA编译器高效地执行代码块。

例如,考虑这个规范化2D矩阵行的函数,它是通过jax.numpy操作表达的:

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

可以使用 jax.jit 转换来创建该函数的即时编译版本:

from jax import jit
norm_compiled = jit(norm)

此函数返回与原始结果相同的结果,直到标准浮点精度为止:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True

但由于编译(包括操作的融合、避免分配临时数组以及其他诸多技巧),在JIT编译的情况下,执行时间可以快几个数量级(注意使用block_until_ready()来考虑JAX的异步分派):

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
297 μs ± 6.59 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
261 μs ± 3.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

话虽如此,jax.jit确实有一些限制:特别是,它要求所有数组具有静态形状。这意味着某些JAX操作与JIT编译不兼容。

例如,此操作可以在逐操作模式下执行:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)

但如果您尝试在jit模式下执行它,则会返回错误:

jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

这是因为该函数生成的数组形状在编译时无法确定:输出的大小取决于输入数组的值,因此与 JIT 不兼容。

JIT 机制:追踪和静态变量#

关键概念:

  • JIT 和其他 JAX 转换通过 追踪 一个函数来确定它对特定形状和类型输入的影响。

  • 你不希望被追踪的变量可以标记为 静态

要有效使用 jax.jit,理解它的工作原理是很有用的。让我们在一个 JIT 编译的函数中放置一些 print() 语句,然后调用这个函数:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.4032435 ], dtype=float32)

注意到打印语句会执行,但它打印的不是我们传递给函数的数据,而是代表它们的tracer对象。

这些tracer对象是jax.jit用来提取由函数指定的操作序列的。基本的tracer是代表者,它们编码了数组的形状数据类型,但与值无关。这个记录的计算序列可以在XLA中高效地应用于具有相同形状和数据类型的新输入,而无需重新执行Python代码。

当我们在匹配的输入上再次调用编译后的函数时,不需要重新编译,并且什么也不会被打印,因为结果是在编译的XLA中计算的,而不是在Python中:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344587, 4.3004417, 7.989701 ], dtype=float32)

提取的操作序列编码在一个 JAX 表达式中,简称为 jaxpr。您可以使用 jax.make_jaxpr 转换来查看 jaxpr:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

注意这一点的一个后果:因为JIT编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于被追踪的值。例如,这将失败:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22855/3703183757.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

如果有一些变量您不希望被追踪,可以将它们标记为静态,以便于JIT编译:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True)
Array(-1, dtype=int32, weak_type=True)

请注意,使用不同的静态参数调用JIT编译的函数将导致重新编译,因此该函数仍然按预期工作:

f(1, False)
Array(1, dtype=int32, weak_type=True)

理解哪些值和操作将是静态的,哪些将被追踪,是有效使用 jax.jit 的关键部分。

静态操作与跟踪操作#

关键概念:

  • 就像值可以是静态的或被跟踪的,操作也可以是静态的或被跟踪的。

  • 静态操作在Python中是在编译时评估的;被跟踪的操作是在XLA中在运行时编译和评估的。

  • 对于希望保持静态的操作,请使用numpy;对于希望被跟踪的操作,请使用jax.numpy

这种静态值与跟踪值之间的区别使得思考如何保持静态值静态变得重要。考虑这个函数:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22855/893595691.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22855/893595691.py:6:19 (f)

这会导致一个错误,错误信息指出找到了一个跟踪器,而不是一个具体的整数类型的一维序列。让我们在函数中添加一些打印语句,以便理解为什么会发生这种情况:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # 注释掉这一行以避免错误:
  # 返回 x 的形状重塑为 jnp.array(x.shape).prod()。

f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>

请注意,尽管 x 是被跟踪的,x.shape 是一个静态值。然而,当我们对这个静态值使用 jnp.arrayjnp.prod 时,它就变成了一个被跟踪的值,此时它不能在像 reshape() 这样需要静态输入的函数中使用(回想一下:数组的形状必须是静态的)。

一个有用的模式是对应该是静态的操作(即在编译时完成的操作)使用 numpy,对应该被跟踪的操作(即在运行时编译和执行的操作)使用 jax.numpy。对于这个函数,它可能看起来像这样:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)

因此,在JAX程序中,一个标准的约定是 import numpy as npimport jax.numpy as jnp,这样可以同时使用这两种接口,以更细致地控制操作是以静态方式(使用 numpy,在编译时执行一次)还是以跟踪方式(使用 jax.numpy,在运行时优化)进行的。