常见问题解答 (FAQ)#

我们正在收集常见问题的答案。欢迎贡献!

jit 改变了我的函数的行为#

如果你有一个Python函数在使用 jax.jit() 后改变了行为,可能你的函数使用了全局状态,或者有副作用。在下面的代码中,impure_func 使用了全局变量 y,并且由于 print 产生了副作用:

y = 0

# @jit   # Different behavior with jit
def impure_func(x):
  print("Inside:", y)
  return x + y

for y in range(3):
  print("Result:", impure_func(y))

没有 jit 时,输出为:

Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4

并且使用 jit 时,它是:

Inside: 0
Result: 0
Result: 1
Result: 2

对于 jax.jit(),函数首先使用 Python 解释器执行一次,此时会打印 Inside,并观察到 y 的第一个值。然后,函数被编译并缓存,之后多次执行,输入不同的 x 值,但 y 的第一个值保持不变。

附加阅读:

jit 改变了输出的精确数值#

有时用户会对用 jit() 包装一个函数可以改变函数的输出这一事实感到惊讶。例如:

>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
...   return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649

这种输出上的细微差异来自于XLA编译器内部的优化:在编译过程中,XLA有时会重新排列或省略某些操作,以使整体计算更加高效。

在这种情况下,XLA 利用对数的性质,将 log(sqrt(x)) 替换为 0.5 * log(x),这是一个在数学上等价的表达式,但计算效率比原表达式更高。输出结果的差异源于浮点运算只是实数运算的一个近似,因此计算同一表达式的不同方式可能会产生微妙的差异。

其他时候,XLA 的优化可能会导致更加显著的差异。考虑以下示例:

>>> def f(x):
...   return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0

在非JIT编译的逐操作模式下,结果是 inf 因为 jnp.exp(x) 溢出并返回 inf。然而,在JIT下,XLA识别到 logexp 的逆运算,并从编译函数中移除了这些操作,直接返回输入。在这种情况下,JIT编译产生了一个更准确的实数结果的浮点近似值。

不幸的是,XLA 的代数简化完整列表没有得到很好的文档化,但如果你熟悉 C++ 并且对 XLA 编译器所做的优化类型感到好奇,你可以在源代码中查看它们:algebraic_simplifier.cc

jit 装饰的函数编译速度非常慢#

如果你的 jit 装饰函数在第一次调用时需要几十秒(或更长时间!)来运行,但在再次调用时执行速度很快,那么 JAX 正在花费大量时间来追踪或编译你的代码。

这通常是一个迹象,表明调用你的函数在JAX的内部表示中生成了大量的代码,通常是因为它大量使用了Python控制流,如 for 循环。对于少量的循环迭代,Python是可以的,但如果你需要 许多 循环迭代,你应该重写你的代码以利用JAX的 结构化控制流原语 (如 lax.scan()),或者避免用 jit 包装循环(你仍然可以在循环 内部 使用 jit 装饰的函数)。

如果你不确定这是否是问题所在,你可以尝试在你的函数上运行 jax.make_jaxpr()。如果输出长达数百或数千行,你可以预期编译速度会变慢。

有时,如何重写代码以避免Python循环并不明显,因为您的代码使用了多种不同形状的数组。在这种情况下,推荐的解决方案是使用 jax.numpy.where() 等函数,在具有固定形状的填充数组上进行计算。

如果你的函数由于其他原因编译速度慢,请在GitHub上提交一个问题。

如何将 jit 用于方法?#

大多数 jax.jit() 的例子都是关于装饰独立的 Python 函数的,但装饰类中的方法会引入一些复杂性。例如,考虑以下简单的类,我们在方法上使用了标准的 jit() 注解:

>>> import jax.numpy as jnp
>>> from jax import jit

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit  # <---- How to do this correctly?
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

然而,当你尝试调用此方法时,这种方法将导致错误:

>>> c = CustomClass(2, True)
>>> c.calc(3)  
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
  File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.

问题在于函数的第一个参数是 self,其类型为 CustomClass,而 JAX 不知道如何处理这种类型。在这种情况下,我们可能使用三种基本策略,下面将讨论它们。

策略 1: JIT 编译辅助函数#

最直接的方法是创建一个类外部的辅助函数,它可以以正常方式进行JIT装饰。例如:

>>> from functools import partial

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   def calc(self, y):
...     return _calc(self.mul, self.x, y)

>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
...   if mul:
...     return x * y
...   return y

结果将按预期工作:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

这种方法的好处是它简单、明确,并且避免了教 JAX 如何处理 CustomClass 类型的对象。然而,你可能希望将所有方法逻辑保持在同一个地方。

策略 2:将 self 标记为静态#

另一种常见模式是使用 static_argnumsself 参数标记为静态。但必须小心操作以避免意外结果。你可能会倾向于简单地这样做:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   # WARNING: this example is broken, as we'll see below. Don't copy & paste!
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y

如果你调用这个方法,它将不再引发错误:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

然而,有一个问题:如果你在第一次方法调用后修改了对象,后续的方法调用可能会返回错误的结果:

>>> c.mul = False
>>> print(c.calc(3))  # Should print 3
6

为什么是这样?当你将一个对象标记为静态时,它将被有效地用作JIT内部编译缓存中的字典键,这意味着其哈希值(即 hash(obj))相等性(即 obj1 == obj2)和对象标识(即 obj1 is obj2)将被假定为具有一致的行为。自定义对象的默认 __hash__ 是其对象ID,因此JAX无法知道一个被修改的对象应该触发重新编译。

你可以通过为你的对象定义适当的 __hash____eq__ 方法来部分解决这个问题;例如:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @partial(jit, static_argnums=0)
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def __hash__(self):
...     return hash((self.x, self.mul))
...
...   def __eq__(self, other):
...     return (isinstance(other, CustomClass) and
...             (self.x, self.mul) == (other.x, other.mul))

(参见 object.__hash__() 文档以获取更多关于重写 __hash__ 时要求的讨论).

这应该能与JIT和其他转换正确工作 只要你从不改变你的对象。作为哈希键的对象的变异会导致几个微妙的问题,这就是为什么例如可变的Python容器(如 dict, list)不定义 __hash__,而它们的不可变对应物(如 tuple)则定义。

如果你的类依赖于就地突变(例如在其方法中设置 self.attr = ...),那么你的对象实际上并不是“静态”的,将其标记为静态可能会导致问题。幸运的是,这种情况还有另一种选择。

策略 3:将 CustomClass 设为 PyTree#

最灵活的正确JIT编译类方法的方法是将类型注册为自定义PyTree对象;参见 扩展 pytrees 。这使您能够精确指定类的哪些部分应被视为静态,哪些部分应被视为动态。以下是它的可能样子:

>>> class CustomClass:
...   def __init__(self, x: jnp.ndarray, mul: bool):
...     self.x = x
...     self.mul = mul
...
...   @jit
...   def calc(self, y):
...     if self.mul:
...       return self.x * y
...     return y
...
...   def _tree_flatten(self):
...     children = (self.x,)  # arrays / dynamic values
...     aux_data = {'mul': self.mul}  # static values
...     return (children, aux_data)
...
...   @classmethod
...   def _tree_unflatten(cls, aux_data, children):
...     return cls(*children, **aux_data)

>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
...                                CustomClass._tree_flatten,
...                                CustomClass._tree_unflatten)

这当然更复杂,但它解决了上述简单方法所关联的所有问题:

>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6

>>> c.mul = False  # mutation is detected
>>> print(c.calc(3))
3

>>> c = CustomClass(jnp.array(2), True)  # non-hashable x is supported
>>> print(c.calc(3))
6

只要你的 tree_flattentree_unflatten 函数正确处理了类中的所有相关属性,你应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。

控制数据和计算在设备上的放置#

首先,我们来看一下JAX中数据和计算放置的原理。

在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1) 数据所在的设备;2) 数据是否**已提交**到设备(数据有时被称为对设备*粘滞*)。

默认情况下,JAX 数组未提交到默认设备(jax.devices()[0]),该设备默认是第一个 GPU 或 TPU。如果没有 GPU 或 TPU,jax.devices()[0] 则是 CPU。可以使用 jax.default_device() 上下文管理器临时覆盖默认设备,或者通过设置环境变量 JAX_PLATFORMS 或 absl 标志 --jax_platforms 为 “cpu”、”gpu” 或 “tpu” 来为整个进程设置默认设备(JAX_PLATFORMS 也可以是一个平台列表,按优先顺序确定哪些平台可用)。

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())  
{CudaDevice(id=0)}

涉及未提交数据的计算在默认设备上执行,结果在默认设备上未提交。

数据也可以使用带有 device 参数的 jax.device_put() 显式放置在设备上,在这种情况下,数据将**提交**到设备:

>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])  
>>> print(arr.devices())  
{CudaDevice(id=2)}

涉及某些已提交输入的计算将在已提交的设备上进行,结果也将提交到同一设备。如果在已提交到多个设备的参数上调用操作,将引发错误。

你也可以在没有 device 参数的情况下使用 jax.device_put()。如果数据已经在一个设备上(已提交或未提交),则保持原样。如果数据不在任何设备上——即它是一个常规的 Python 或 NumPy 值——它将被放置在默认设备上,且未提交。

Jitted 函数的行为类似于任何其他基本操作——它们将遵循数据,如果在多个设备上提交的数据上调用,将显示错误。

(在2021年3月 PR #6002 之前,数组常量的创建存在一些惰性,因此 jax.device_put(jnp.zeros(...), jax.devices()[1]) 或类似的操作实际上会在 jax.devices()[1] 上创建零数组,而不是在默认设备上创建数组然后移动它。但为了简化实现,这个优化被移除了。)

(截至2020年4月,jax.jit() 有一个 device 参数,该参数影响设备分配。该参数是实验性的,可能会被移除或更改,不建议使用。)

对于一个详细的示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data

基准测试 JAX 代码#

你刚刚将一个复杂的函数从 NumPy/SciPy 移植到了 JAX。这真的加快了速度吗?

在使用 JAX 测量代码速度时,请记住与 NumPy 相比的这些重要差异:

  1. JAX 代码是即时(JIT)编译的。 大多数用 JAX 编写的代码可以以支持 JIT 编译的方式编写,这可以使它运行得 快得多 (参见 To JIT or not to JIT)。为了从 JAX 中获得最大性能,你应该在你的最外层函数调用上应用 jax.jit()

    请记住,第一次运行 JAX 代码时,它会较慢,因为它正在被编译。即使你不在自己的代码中使用 jit ,这也是正确的,因为 JAX 的内置函数也是即时编译的。

  2. JAX 具有异步调度。 这意味着你需要调用 .block_until_ready() 来确保计算实际上已经发生(参见 异步分发)。

  3. JAX 默认只使用 32 位数据类型。 你可能希望在 NumPy 中显式使用 32 位数据类型,或在 JAX 中启用 64 位数据类型(参见 双精度(64 位))以进行公平比较。

  4. 在CPU和加速器之间传输数据需要时间。 如果你只想测量评估一个函数所需的时间,你可能希望先将数据传输到你想要运行它的设备上(参见 控制数据和计算在设备上的放置)。

以下是如何将所有这些技巧组合成一个微基准测试的示例,用于比较 JAX 和 NumPy,利用 IPython 的便捷 %time 和 %timeit magics:

import numpy as np
import jax.numpy as jnp
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
  return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime

当在 Colab 中使用 GPU 运行时,我们看到:

  • 在CPU上,NumPy每次评估需要16.2毫秒

  • JAX 需要 1.26 毫秒将 NumPy 数组复制到 GPU 上

  • JAX 需要 193 毫秒来编译函数

  • 在GPU上,JAX每次评估需要485微秒

在这种情况下,我们看到一旦数据被传输并且函数被编译,JAX 在 GPU 上的重复评估速度大约快了 30 倍。

这是一个公平的比较吗?也许吧。最终重要的是运行完整应用程序的性能,这不可避免地包括一定量的数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@ 运算符正在进行矩阵-矩阵乘法),以摊销JAX/加速器与NumPy/CPU增加的开销。例如,如果我们把这个例子改为使用10x10的输入,JAX/GPU的运行速度比NumPy/CPU慢10倍(100 µs vs 10 µs)。

JAX 比 NumPy 更快吗?#

用户经常试图通过此类基准测试来回答的一个问题是 JAX 是否比 NumPy 更快;由于这两个包的差异,没有简单的答案。

广义上讲:

  • NumPy 操作是即时执行的,同步的,并且仅在 CPU 上执行。

  • JAX 操作可以在编译后执行(如果在 jit() 内部),也可以立即执行;它们是异步分派的(参见 异步分发);并且它们可以在 CPU、GPU 或 TPU 上执行,每种设备都有截然不同且不断演变的性能特征。

这些架构上的差异使得在 NumPy 和 JAX 之间进行有意义的直接基准比较变得困难。

此外,这些差异导致了不同软件包之间的工程重点不同:例如,NumPy在减少单个数组操作的每次调用分派开销方面投入了大量精力,因为在NumPy的计算模型中,这种开销是无法避免的。而JAX则有几种避免分派开销的方法(例如JIT编译、异步分派、批处理转换等),因此减少每次调用的开销并不是优先考虑的事项。

记住这些,总结来说:如果你在CPU上对单个数组操作进行微基准测试,通常可以预期NumPy由于其较低的每次操作调度开销而优于JAX。如果你在GPU或TPU上运行代码,或者在CPU上对更复杂的JIT编译操作序列进行基准测试,通常可以预期JAX优于NumPy。

不同种类的 JAX 值#

在函数转换过程中,JAX 将一些函数参数替换为特殊的跟踪值。

如果你使用 print 语句,你可以看到这个:

def func(x):
  print(x)
  return jnp.cos(x)

res = jax.jit(func)(0.)

上述代码确实返回了正确的值 1.,但它也为 x 的值打印了 Traced<ShapedArray(float32[])>。通常情况下,JAX 以透明的方式在内部处理这些跟踪值,例如,在用于实现 jax.numpy 函数的数值 JAX 原语中。这就是为什么 jnp.cos 在上面的例子中可以工作。

更准确地说,对于JAX转换函数的参数引入了一个 tracer 值,除了那些由特殊参数标识的参数,例如 jax.jit()static_argnumsjax.pmap()static_broadcasted_argnums。通常,涉及至少一个tracer值的计算将产生一个tracer值。除了tracer值之外,还有 regular Python值:这些值是在JAX转换之外计算的,或者来自上述某些JAX转换的静态参数,或者仅从其他regular Python值计算得出。在没有JAX转换的情况下,这些值在各处使用。

一个跟踪值携带一个 抽象 值,例如,ShapedArray 包含数组的形状和数据类型的信息。我们在这里将这些跟踪值称为 抽象跟踪值。一些跟踪值,例如,那些为自动微分变换的参数引入的跟踪值,携带 ConcreteArray 抽象值,这些抽象值实际上包含常规数组数据,并用于解析条件等。我们在这里将这些跟踪值称为 具体跟踪值。从这些具体跟踪值计算出的跟踪值,可能与常规值结合,结果是具体跟踪值。一个 具体值 要么是一个常规值,要么是一个具体跟踪值。

大多数情况下,从跟踪器值计算出的值本身也是跟踪器值。只有极少数例外,当计算可以完全使用跟踪器携带的抽象值完成时,结果可以是常规值。例如,获取具有 ShapedArray 抽象值的跟踪器的形状。另一个例子是当显式地将具体跟踪器值转换为常规类型时,例如 int(x)x.astype(float)。另一种情况是 bool(x),当具体性使其可能时,它会生成一个 Python bool。这种情况尤为突出,因为它在控制流中经常出现。

以下是如何通过转换引入抽象或具体的示踪剂:

  • jax.jit() 为所有位置参数引入 抽象追踪器,除了那些由 static_argnums 表示的参数,它们保持常规值。

  • jax.pmap(): 为所有位置参数引入**抽象追踪器**,除了那些由 static_broadcasted_argnums 表示的参数。

  • jax.vmap(), jax.make_jaxpr(), xla_computation(): 为所有位置参数引入 抽象追踪器

  • jax.jvp()jax.grad() 为所有位置参数引入 具体跟踪器。当这些变换位于外部变换内且实际参数本身是抽象跟踪器时,则自动微分变换引入的跟踪器也是抽象跟踪器。

  • 所有高阶控制流原语(lax.cond()lax.while_loop()lax.fori_loop()lax.scan())在处理函数时都会引入 抽象追踪器,无论是否正在进行 JAX 变换。

当你有只能操作常规Python值的代码时,所有这些都是相关的,例如基于数据的控制流条件代码:

def divide(x, y):
  return x / y if y >= 1. else 0.

如果我们想应用 jax.jit(),我们必须确保指定 static_argnums=1 以确保 y 保持为一个常规值。这是由于布尔表达式 y >= 1.,它需要具体值(常规值或跟踪器)。如果我们显式地写 bool(y >= 1.),或者 int(y),或者 float(y),同样的情况也会发生。

有趣的是,jax.grad(divide)(3., 2.) 可以工作,因为 jax.grad() 使用了具体的追踪器,并使用 y 的具体值来解析条件。

Buffer donation#

当 JAX 执行计算时,它会在设备上为所有输入和输出使用缓冲区。如果你知道某个输入在计算后不再需要,并且它的形状和元素类型与某个输出匹配,你可以指定希望将相应的输入缓冲区捐赠以保存输出。这将通过捐赠缓冲区的大小来减少执行所需的内存。

如果你有类似以下的模式,你可以使用缓冲区捐赠:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)

你可以把这看作是对你的不可变 JAX 数组进行内存高效的功能性更新的方法。在计算的边界内,XLA 可以为你进行这种优化,但在 jit/pmap 边界处,你需要向 XLA 保证在调用捐赠函数后不会使用捐赠的输入缓冲区。

您可以通过在函数 jax.jit()jax.pjit()jax.pmap() 中使用 donate_argnums 参数来实现这一点。此参数是一个索引序列(基于0),指向位置参数列表:

def add(x, y):
  return x + y

x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)

请注意,当前在使用关键字参数调用您的函数时,这不会生效!以下代码不会捐赠任何缓冲区:

params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)

如果一个参数的缓冲区被捐赠,并且它是一个 pytree,那么它的所有组件的缓冲区都会被捐赠:

def add_ones(xs: List[Array]):
  return [x + 1 for x in xs]

xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)

不允许捐赠一个随后在计算中使用的缓冲区,JAX 会报错,因为 y 的缓冲区在捐赠后变得无效:

# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1  # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer

如果捐赠的缓冲区未被使用,例如因为捐赠的缓冲区数量超过了输出所需的数量,您将收到警告:

# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}

如果没有输出形状与捐赠匹配,捐赠也可能未被使用:

y = jax.device_put(np.ones((1, 3)))  # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}

梯度在使用 where 的地方包含 NaN#

如果你使用 where 定义一个函数以避免未定义的值,如果你不小心,可能会得到一个反向微分的 NaN:

def my_log(x):
  return jnp.where(x > 0., jnp.log(x), 0.)

my_log(0.) ==> 0.  # Ok
jax.grad(my_log)(0.)  ==> NaN

简而言之,在 grad 计算过程中,对应于未定义的 jnp.log(x) 的伴随项是一个 NaN,并且它会累积到 jnp.where 的伴随项中。编写此类函数的正确方法是确保在部分定义的函数 内部 有一个 jnp.where,以确保伴随项始终是有限的:

def safe_for_grad_log(x):
  return jnp.log(jnp.where(x > 0., x, 1.))

safe_for_grad_log(0.) ==> 0.  # Ok
jax.grad(safe_for_grad_log)(0.)  ==> 0.  # Ok

除了原始的 jnp.where 之外,可能还需要内部的 jnp.where,例如:

def my_log_or_y(x, y):
  """Return log(x) if x > 0 or y"""
  return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)

附加阅读:

为什么基于排序顺序的函数的梯度为零?#

如果你定义了一个函数,该函数使用依赖于输入相对顺序的操作(例如 maxgreaterargsort 等)来处理输入,那么你可能会惊讶地发现梯度在所有地方都是零。这里有一个例子,我们定义 f(x) 为一个阶跃函数,当 x 为负时返回 0,当 x 为正时返回 1:

import jax
import numpy as np
import jax.numpy as jnp

def f(x):
  return (x > 0).astype(float)

df = jax.vmap(jax.grad(f))

x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])

print(f"f(x)  = {f(x)}")
# f(x)  = [0. 0. 0. 1. 1.]

print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]

梯度在各处都为零这一事实乍一看可能会令人困惑:毕竟,输出确实会根据输入的变化而变化,那么梯度怎么会是零呢?然而,在这种情况下,零实际上是正确的结果。

为什么是这样?记住,微分测量的是在 x 发生极小变化时 f 的变化。对于 x=1.0f 返回 1.0。如果我们扰动 x 使其稍微增大或减小,这不会改变输出,因此根据定义,grad(f)(1.0) 应该是零。这种逻辑对所有大于零的 f 值都成立:极小地扰动输入不会改变输出,所以梯度是零。同样,对于所有小于零的 x 值,输出是零。扰动 x 不会改变这个输出,所以梯度是零。这给我们留下了 x=0 这个棘手的情况。当然,如果你向上扰动 x,它会改变输出,但这有问题:x 的极小变化在函数值中产生有限的变化,这意味着梯度是未定义的。幸运的是,在这种情况下,我们还有另一种测量梯度的方法:我们向下扰动函数,在这种情况下输出不会改变,因此梯度是零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个是定义的而另一个不是,我们使用定义的那个。根据这种梯度定义,从数学和数值上来说,这个函数的梯度处处为零。

问题源于我们的函数在 x = 0 处有一个不连续性。这里的 f 本质上是一个 Heaviside Step Function,我们可以使用 Sigmoid Function 作为平滑的替代。当 x 远离零时,Sigmoid 函数近似等于 Heaviside 函数,但在 x = 0 处用一个平滑、可微的曲线替代了不连续性。通过使用 jax.nn.sigmoid(),我们得到了一个类似的计算,并且具有定义良好的梯度:

def g(x):
  return jax.nn.sigmoid(x)

dg = jax.vmap(jax.grad(g))

x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])

with np.printoptions(suppress=True, precision=2):
  print(f"g(x)  = {g(x)}")
  # g(x)  = [0.   0.27 0.5  0.73 1.  ]

  print(f"dg(x) = {dg(x)}")
  # dg(x) = [0.   0.2  0.25 0.2  0.  ]

The jax.nn 子模块也有其他常见基于秩函数的平滑版本,例如 jax.nn.softmax() 可以替代 jax.numpy.argmax() 的使用,jax.nn.soft_sign() 可以替代 jax.numpy.sign() 的使用,jax.nn.softplus()jax.nn.squareplus() 可以替代 jax.nn.relu() 的使用,等等。

如何将 JAX Tracer 转换为 NumPy 数组?#

在运行时检查转换后的 JAX 函数时,你会发现数组值被替换为 Tracer 对象:

@jax.jit
def f(x):
  print(type(x))
  return x

f(jnp.arange(5))

这将打印以下内容:

<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>

一个常见的问题是如何将这样的跟踪器转换回正常的 NumPy 数组。简而言之,无法将跟踪器转换为 NumPy 数组,因为跟踪器是具有给定形状和数据类型的*每个可能*值的抽象表示,而 numpy 数组是该抽象类的一个具体成员。有关跟踪器如何在 JAX 变换的上下文中工作的更多讨论,请参见 JIT 机制

将 Tracers 转换回数组的问题通常出现在另一个目标的背景下,与在运行时访问计算中的中间值有关。例如:

有关运行时回调及其使用示例的更多信息,请参阅 JAX 中的外部回调

为什么某些CUDA库无法加载/初始化?#

在解析动态库时,JAX 使用通常的 动态链接器搜索模式。JAX 设置 RPATH 指向 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装则优先使用。如果 ld.so 无法在其通常的搜索路径上找到您的 CUDA 运行时库,那么您必须将这些库的路径显式包含在 LD_LIBRARY_PATH 中。确保您的 CUDA 文件可被发现的简单方法是安装 nvidia-*-cu12 pip 包,这些包包含在标准的 jax[cuda_12] 安装选项中。

偶尔,即使你确保了你的运行时库是可发现的,加载或初始化它们时仍可能出现问题。这类问题的常见原因是运行时CUDA库初始化所需的内存不足。这种情况有时发生是因为JAX会为更快的执行预先分配当前可用设备内存的较大块,偶尔导致剩余的内存不足以用于运行时CUDA库的初始化。

在运行多个 JAX 实例、与执行自身预分配的 TensorFlow 并行运行 JAX,或在 GPU 被其他进程大量使用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试通过减少 XLA_PYTHON_CLIENT_MEM_FRACTION 从默认的 .75,或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false 来减少预分配,再次运行程序。更多详情,请参阅 JAX GPU 内存分配 页面。