常见问题解答 (FAQ)#
我们正在收集常见问题的答案。欢迎贡献!
jit
改变了我的函数的行为#
如果你有一个Python函数在使用 jax.jit()
后改变了行为,可能你的函数使用了全局状态,或者有副作用。在下面的代码中,impure_func
使用了全局变量 y
,并且由于 print
产生了副作用:
y = 0
# @jit # Different behavior with jit
def impure_func(x):
print("Inside:", y)
return x + y
for y in range(3):
print("Result:", impure_func(y))
没有 jit
时,输出为:
Inside: 0
Result: 0
Inside: 1
Result: 2
Inside: 2
Result: 4
并且使用 jit
时,它是:
Inside: 0
Result: 0
Result: 1
Result: 2
对于 jax.jit()
,函数首先使用 Python 解释器执行一次,此时会打印 Inside
,并观察到 y
的第一个值。然后,函数被编译并缓存,之后多次执行,输入不同的 x
值,但 y
的第一个值保持不变。
附加阅读:
jit
改变了输出的精确数值#
有时用户会对用 jit()
包装一个函数可以改变函数的输出这一事实感到惊讶。例如:
>>> from jax import jit
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.log(jnp.sqrt(x))
>>> x = jnp.pi
>>> print(f(x))
0.572365
>>> print(jit(f)(x))
0.5723649
这种输出上的细微差异来自于XLA编译器内部的优化:在编译过程中,XLA有时会重新排列或省略某些操作,以使整体计算更加高效。
在这种情况下,XLA 利用对数的性质,将 log(sqrt(x))
替换为 0.5 * log(x)
,这是一个在数学上等价的表达式,但计算效率比原表达式更高。输出结果的差异源于浮点运算只是实数运算的一个近似,因此计算同一表达式的不同方式可能会产生微妙的差异。
其他时候,XLA 的优化可能会导致更加显著的差异。考虑以下示例:
>>> def f(x):
... return jnp.log(jnp.exp(x))
>>> x = 100.0
>>> print(f(x))
inf
>>> print(jit(f)(x))
100.0
在非JIT编译的逐操作模式下,结果是 inf
因为 jnp.exp(x)
溢出并返回 inf
。然而,在JIT下,XLA识别到 log
是 exp
的逆运算,并从编译函数中移除了这些操作,直接返回输入。在这种情况下,JIT编译产生了一个更准确的实数结果的浮点近似值。
不幸的是,XLA 的代数简化完整列表没有得到很好的文档化,但如果你熟悉 C++ 并且对 XLA 编译器所做的优化类型感到好奇,你可以在源代码中查看它们:algebraic_simplifier.cc。
jit
装饰的函数编译速度非常慢#
如果你的 jit
装饰函数在第一次调用时需要几十秒(或更长时间!)来运行,但在再次调用时执行速度很快,那么 JAX 正在花费大量时间来追踪或编译你的代码。
这通常是一个迹象,表明调用你的函数在JAX的内部表示中生成了大量的代码,通常是因为它大量使用了Python控制流,如 for
循环。对于少量的循环迭代,Python是可以的,但如果你需要 许多 循环迭代,你应该重写你的代码以利用JAX的 结构化控制流原语 (如 lax.scan()
),或者避免用 jit
包装循环(你仍然可以在循环 内部 使用 jit
装饰的函数)。
如果你不确定这是否是问题所在,你可以尝试在你的函数上运行 jax.make_jaxpr()
。如果输出长达数百或数千行,你可以预期编译速度会变慢。
有时,如何重写代码以避免Python循环并不明显,因为您的代码使用了多种不同形状的数组。在这种情况下,推荐的解决方案是使用 jax.numpy.where()
等函数,在具有固定形状的填充数组上进行计算。
如果你的函数由于其他原因编译速度慢,请在GitHub上提交一个问题。
如何将 jit
用于方法?#
大多数 jax.jit()
的例子都是关于装饰独立的 Python 函数的,但装饰类中的方法会引入一些复杂性。例如,考虑以下简单的类,我们在方法上使用了标准的 jit()
注解:
>>> import jax.numpy as jnp
>>> from jax import jit
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit # <---- How to do this correctly?
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
然而,当你尝试调用此方法时,这种方法将导致错误:
>>> c = CustomClass(2, True)
>>> c.calc(3)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File "<stdin>", line 1, in <module
TypeError: Argument '<CustomClass object at 0x7f7dd4125890>' of type <class 'CustomClass'> is not a valid JAX type.
问题在于函数的第一个参数是 self
,其类型为 CustomClass
,而 JAX 不知道如何处理这种类型。在这种情况下,我们可能使用三种基本策略,下面将讨论它们。
策略 1: JIT 编译辅助函数#
最直接的方法是创建一个类外部的辅助函数,它可以以正常方式进行JIT装饰。例如:
>>> from functools import partial
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... def calc(self, y):
... return _calc(self.mul, self.x, y)
>>> @partial(jit, static_argnums=0)
... def _calc(mul, x, y):
... if mul:
... return x * y
... return y
结果将按预期工作:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
这种方法的好处是它简单、明确,并且避免了教 JAX 如何处理 CustomClass
类型的对象。然而,你可能希望将所有方法逻辑保持在同一个地方。
策略 2:将 self
标记为静态#
另一种常见模式是使用 static_argnums
将 self
参数标记为静态。但必须小心操作以避免意外结果。你可能会倾向于简单地这样做:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... # WARNING: this example is broken, as we'll see below. Don't copy & paste!
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
如果你调用这个方法,它将不再引发错误:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
然而,有一个问题:如果你在第一次方法调用后修改了对象,后续的方法调用可能会返回错误的结果:
>>> c.mul = False
>>> print(c.calc(3)) # Should print 3
6
为什么是这样?当你将一个对象标记为静态时,它将被有效地用作JIT内部编译缓存中的字典键,这意味着其哈希值(即 hash(obj)
)相等性(即 obj1 == obj2
)和对象标识(即 obj1 is obj2
)将被假定为具有一致的行为。自定义对象的默认 __hash__
是其对象ID,因此JAX无法知道一个被修改的对象应该触发重新编译。
你可以通过为你的对象定义适当的 __hash__
和 __eq__
方法来部分解决这个问题;例如:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @partial(jit, static_argnums=0)
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def __hash__(self):
... return hash((self.x, self.mul))
...
... def __eq__(self, other):
... return (isinstance(other, CustomClass) and
... (self.x, self.mul) == (other.x, other.mul))
(参见 object.__hash__()
文档以获取更多关于重写 __hash__
时要求的讨论).
这应该能与JIT和其他转换正确工作 只要你从不改变你的对象。作为哈希键的对象的变异会导致几个微妙的问题,这就是为什么例如可变的Python容器(如 dict
, list
)不定义 __hash__
,而它们的不可变对应物(如 tuple
)则定义。
如果你的类依赖于就地突变(例如在其方法中设置 self.attr = ...
),那么你的对象实际上并不是“静态”的,将其标记为静态可能会导致问题。幸运的是,这种情况还有另一种选择。
策略 3:将 CustomClass
设为 PyTree#
最灵活的正确JIT编译类方法的方法是将类型注册为自定义PyTree对象;参见 扩展 pytrees 。这使您能够精确指定类的哪些部分应被视为静态,哪些部分应被视为动态。以下是它的可能样子:
>>> class CustomClass:
... def __init__(self, x: jnp.ndarray, mul: bool):
... self.x = x
... self.mul = mul
...
... @jit
... def calc(self, y):
... if self.mul:
... return self.x * y
... return y
...
... def _tree_flatten(self):
... children = (self.x,) # arrays / dynamic values
... aux_data = {'mul': self.mul} # static values
... return (children, aux_data)
...
... @classmethod
... def _tree_unflatten(cls, aux_data, children):
... return cls(*children, **aux_data)
>>> from jax import tree_util
>>> tree_util.register_pytree_node(CustomClass,
... CustomClass._tree_flatten,
... CustomClass._tree_unflatten)
这当然更复杂,但它解决了上述简单方法所关联的所有问题:
>>> c = CustomClass(2, True)
>>> print(c.calc(3))
6
>>> c.mul = False # mutation is detected
>>> print(c.calc(3))
3
>>> c = CustomClass(jnp.array(2), True) # non-hashable x is supported
>>> print(c.calc(3))
6
只要你的 tree_flatten
和 tree_unflatten
函数正确处理了类中的所有相关属性,你应该能够直接将此类型的对象用作 JIT 编译函数的参数,而无需任何特殊注解。
控制数据和计算在设备上的放置#
首先,我们来看一下JAX中数据和计算放置的原理。
在 JAX 中,计算遵循数据放置。JAX 数组有两个放置属性:1) 数据所在的设备;2) 数据是否**已提交**到设备(数据有时被称为对设备*粘滞*)。
默认情况下,JAX 数组未提交到默认设备(jax.devices()[0]
),该设备默认是第一个 GPU 或 TPU。如果没有 GPU 或 TPU,jax.devices()[0]
则是 CPU。可以使用 jax.default_device()
上下文管理器临时覆盖默认设备,或者通过设置环境变量 JAX_PLATFORMS
或 absl 标志 --jax_platforms
为 “cpu”、”gpu” 或 “tpu” 来为整个进程设置默认设备(JAX_PLATFORMS
也可以是一个平台列表,按优先顺序确定哪些平台可用)。
>>> from jax import numpy as jnp
>>> print(jnp.ones(3).devices())
{CudaDevice(id=0)}
涉及未提交数据的计算在默认设备上执行,结果在默认设备上未提交。
数据也可以使用带有 device
参数的 jax.device_put()
显式放置在设备上,在这种情况下,数据将**提交**到设备:
>>> import jax
>>> from jax import device_put
>>> arr = device_put(1, jax.devices()[2])
>>> print(arr.devices())
{CudaDevice(id=2)}
涉及某些已提交输入的计算将在已提交的设备上进行,结果也将提交到同一设备。如果在已提交到多个设备的参数上调用操作,将引发错误。
你也可以在没有 device
参数的情况下使用 jax.device_put()
。如果数据已经在一个设备上(已提交或未提交),则保持原样。如果数据不在任何设备上——即它是一个常规的 Python 或 NumPy 值——它将被放置在默认设备上,且未提交。
Jitted 函数的行为类似于任何其他基本操作——它们将遵循数据,如果在多个设备上提交的数据上调用,将显示错误。
(在2021年3月 PR #6002 之前,数组常量的创建存在一些惰性,因此 jax.device_put(jnp.zeros(...), jax.devices()[1])
或类似的操作实际上会在 jax.devices()[1]
上创建零数组,而不是在默认设备上创建数组然后移动它。但为了简化实现,这个优化被移除了。)
(截至2020年4月,jax.jit()
有一个 device 参数,该参数影响设备分配。该参数是实验性的,可能会被移除或更改,不建议使用。)
对于一个详细的示例,我们建议阅读 multi_device_test.py 中的 test_computation_follows_data
。
基准测试 JAX 代码#
你刚刚将一个复杂的函数从 NumPy/SciPy 移植到了 JAX。这真的加快了速度吗?
在使用 JAX 测量代码速度时,请记住与 NumPy 相比的这些重要差异:
JAX 代码是即时(JIT)编译的。 大多数用 JAX 编写的代码可以以支持 JIT 编译的方式编写,这可以使它运行得 快得多 (参见 To JIT or not to JIT)。为了从 JAX 中获得最大性能,你应该在你的最外层函数调用上应用
jax.jit()
。请记住,第一次运行 JAX 代码时,它会较慢,因为它正在被编译。即使你不在自己的代码中使用
jit
,这也是正确的,因为 JAX 的内置函数也是即时编译的。JAX 具有异步调度。 这意味着你需要调用
.block_until_ready()
来确保计算实际上已经发生(参见 异步分发)。JAX 默认只使用 32 位数据类型。 你可能希望在 NumPy 中显式使用 32 位数据类型,或在 JAX 中启用 64 位数据类型(参见 双精度(64 位))以进行公平比较。
在CPU和加速器之间传输数据需要时间。 如果你只想测量评估一个函数所需的时间,你可能希望先将数据传输到你想要运行它的设备上(参见 控制数据和计算在设备上的放置)。
以下是如何将所有这些技巧组合成一个微基准测试的示例,用于比较 JAX 和 NumPy,利用 IPython 的便捷 %time 和 %timeit magics:
import numpy as np
import jax.numpy as jnp
import jax
def f(x): # function we're benchmarking (works in both NumPy & JAX)
return x.T @ (x - x.mean(axis=0))
x_np = np.ones((1000, 1000), dtype=np.float32) # same as JAX default dtype
%timeit f(x_np) # measure NumPy runtime
%time x_jax = jax.device_put(x_np) # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready() # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready() # measure JAX runtime
当在 Colab 中使用 GPU 运行时,我们看到:
在CPU上,NumPy每次评估需要16.2毫秒
JAX 需要 1.26 毫秒将 NumPy 数组复制到 GPU 上
JAX 需要 193 毫秒来编译函数
在GPU上,JAX每次评估需要485微秒
在这种情况下,我们看到一旦数据被传输并且函数被编译,JAX 在 GPU 上的重复评估速度大约快了 30 倍。
这是一个公平的比较吗?也许吧。最终重要的是运行完整应用程序的性能,这不可避免地包括一定量的数据传输和编译。此外,我们小心地选择了足够大的数组(1000x1000)和足够密集的计算(@
运算符正在进行矩阵-矩阵乘法),以摊销JAX/加速器与NumPy/CPU增加的开销。例如,如果我们把这个例子改为使用10x10的输入,JAX/GPU的运行速度比NumPy/CPU慢10倍(100 µs vs 10 µs)。
JAX 比 NumPy 更快吗?#
用户经常试图通过此类基准测试来回答的一个问题是 JAX 是否比 NumPy 更快;由于这两个包的差异,没有简单的答案。
广义上讲:
NumPy 操作是即时执行的,同步的,并且仅在 CPU 上执行。
JAX 操作可以在编译后执行(如果在
jit()
内部),也可以立即执行;它们是异步分派的(参见 异步分发);并且它们可以在 CPU、GPU 或 TPU 上执行,每种设备都有截然不同且不断演变的性能特征。
这些架构上的差异使得在 NumPy 和 JAX 之间进行有意义的直接基准比较变得困难。
此外,这些差异导致了不同软件包之间的工程重点不同:例如,NumPy在减少单个数组操作的每次调用分派开销方面投入了大量精力,因为在NumPy的计算模型中,这种开销是无法避免的。而JAX则有几种避免分派开销的方法(例如JIT编译、异步分派、批处理转换等),因此减少每次调用的开销并不是优先考虑的事项。
记住这些,总结来说:如果你在CPU上对单个数组操作进行微基准测试,通常可以预期NumPy由于其较低的每次操作调度开销而优于JAX。如果你在GPU或TPU上运行代码,或者在CPU上对更复杂的JIT编译操作序列进行基准测试,通常可以预期JAX优于NumPy。
不同种类的 JAX 值#
在函数转换过程中,JAX 将一些函数参数替换为特殊的跟踪值。
如果你使用 print
语句,你可以看到这个:
def func(x):
print(x)
return jnp.cos(x)
res = jax.jit(func)(0.)
上述代码确实返回了正确的值 1.
,但它也为 x
的值打印了 Traced<ShapedArray(float32[])>
。通常情况下,JAX 以透明的方式在内部处理这些跟踪值,例如,在用于实现 jax.numpy
函数的数值 JAX 原语中。这就是为什么 jnp.cos
在上面的例子中可以工作。
更准确地说,对于JAX转换函数的参数引入了一个 tracer 值,除了那些由特殊参数标识的参数,例如 jax.jit()
的 static_argnums
或 jax.pmap()
的 static_broadcasted_argnums
。通常,涉及至少一个tracer值的计算将产生一个tracer值。除了tracer值之外,还有 regular Python值:这些值是在JAX转换之外计算的,或者来自上述某些JAX转换的静态参数,或者仅从其他regular Python值计算得出。在没有JAX转换的情况下,这些值在各处使用。
一个跟踪值携带一个 抽象 值,例如,ShapedArray
包含数组的形状和数据类型的信息。我们在这里将这些跟踪值称为 抽象跟踪值。一些跟踪值,例如,那些为自动微分变换的参数引入的跟踪值,携带 ConcreteArray
抽象值,这些抽象值实际上包含常规数组数据,并用于解析条件等。我们在这里将这些跟踪值称为 具体跟踪值。从这些具体跟踪值计算出的跟踪值,可能与常规值结合,结果是具体跟踪值。一个 具体值 要么是一个常规值,要么是一个具体跟踪值。
大多数情况下,从跟踪器值计算出的值本身也是跟踪器值。只有极少数例外,当计算可以完全使用跟踪器携带的抽象值完成时,结果可以是常规值。例如,获取具有 ShapedArray
抽象值的跟踪器的形状。另一个例子是当显式地将具体跟踪器值转换为常规类型时,例如 int(x)
或 x.astype(float)
。另一种情况是 bool(x)
,当具体性使其可能时,它会生成一个 Python bool。这种情况尤为突出,因为它在控制流中经常出现。
以下是如何通过转换引入抽象或具体的示踪剂:
jax.jit()
为所有位置参数引入 抽象追踪器,除了那些由static_argnums
表示的参数,它们保持常规值。jax.pmap()
: 为所有位置参数引入**抽象追踪器**,除了那些由static_broadcasted_argnums
表示的参数。jax.vmap()
,jax.make_jaxpr()
,xla_computation()
: 为所有位置参数引入 抽象追踪器。jax.jvp()
和jax.grad()
为所有位置参数引入 具体跟踪器。当这些变换位于外部变换内且实际参数本身是抽象跟踪器时,则自动微分变换引入的跟踪器也是抽象跟踪器。所有高阶控制流原语(
lax.cond()
、lax.while_loop()
、lax.fori_loop()
、lax.scan()
)在处理函数时都会引入 抽象追踪器,无论是否正在进行 JAX 变换。
当你有只能操作常规Python值的代码时,所有这些都是相关的,例如基于数据的控制流条件代码:
def divide(x, y):
return x / y if y >= 1. else 0.
如果我们想应用 jax.jit()
,我们必须确保指定 static_argnums=1
以确保 y
保持为一个常规值。这是由于布尔表达式 y >= 1.
,它需要具体值(常规值或跟踪器)。如果我们显式地写 bool(y >= 1.)
,或者 int(y)
,或者 float(y)
,同样的情况也会发生。
有趣的是,jax.grad(divide)(3., 2.)
可以工作,因为 jax.grad()
使用了具体的追踪器,并使用 y
的具体值来解析条件。
Buffer donation#
当 JAX 执行计算时,它会在设备上为所有输入和输出使用缓冲区。如果你知道某个输入在计算后不再需要,并且它的形状和元素类型与某个输出匹配,你可以指定希望将相应的输入缓冲区捐赠以保存输出。这将通过捐赠缓冲区的大小来减少执行所需的内存。
如果你有类似以下的模式,你可以使用缓冲区捐赠:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params, state)
你可以把这看作是对你的不可变 JAX 数组进行内存高效的功能性更新的方法。在计算的边界内,XLA 可以为你进行这种优化,但在 jit/pmap 边界处,你需要向 XLA 保证在调用捐赠函数后不会使用捐赠的输入缓冲区。
您可以通过在函数 jax.jit()
、jax.pjit()
和 jax.pmap()
中使用 donate_argnums 参数来实现这一点。此参数是一个索引序列(基于0),指向位置参数列表:
def add(x, y):
return x + y
x = jax.device_put(np.ones((2, 3)))
y = jax.device_put(np.ones((2, 3)))
# Execute `add` with donation of the buffer for `y`. The result has
# the same shape and type as `y`, so it will share its buffer.
z = jax.jit(add, donate_argnums=(1,))(x, y)
请注意,当前在使用关键字参数调用您的函数时,这不会生效!以下代码不会捐赠任何缓冲区:
params, state = jax.pmap(update_fn, donate_argnums=(0, 1))(params=params, state=state)
如果一个参数的缓冲区被捐赠,并且它是一个 pytree,那么它的所有组件的缓冲区都会被捐赠:
def add_ones(xs: List[Array]):
return [x + 1 for x in xs]
xs = [jax.device_put(np.ones((2, 3))), jax.device_put(np.ones((3, 4)))]
# Execute `add_ones` with donation of all the buffers for `xs`.
# The outputs have the same shape and type as the elements of `xs`,
# so they will share those buffers.
z = jax.jit(add_ones, donate_argnums=0)(xs)
不允许捐赠一个随后在计算中使用的缓冲区,JAX 会报错,因为 y 的缓冲区在捐赠后变得无效:
# Donate the buffer for `y`
z = jax.jit(add, donate_argnums=(1,))(x, y)
w = y + 1 # Reuses `y` whose buffer was donated above
# >> RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer
如果捐赠的缓冲区未被使用,例如因为捐赠的缓冲区数量超过了输出所需的数量,您将收到警告:
# Execute `add` with donation of the buffers for both `x` and `y`.
# One of those buffers will be used for the result, but the other will
# not be used.
z = jax.jit(add, donate_argnums=(0, 1))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[2,3]{1,0}
如果没有输出形状与捐赠匹配,捐赠也可能未被使用:
y = jax.device_put(np.ones((1, 3))) # `y` has different shape than the output
# Execute `add` with donation of the buffer for `y`.
z = jax.jit(add, donate_argnums=(1,))(x, y)
# >> UserWarning: Some donated buffers were not usable: f32[1,3]{1,0}
梯度在使用 where
的地方包含 NaN#
如果你使用 where
定义一个函数以避免未定义的值,如果你不小心,可能会得到一个反向微分的 NaN
:
def my_log(x):
return jnp.where(x > 0., jnp.log(x), 0.)
my_log(0.) ==> 0. # Ok
jax.grad(my_log)(0.) ==> NaN
简而言之,在 grad
计算过程中,对应于未定义的 jnp.log(x)
的伴随项是一个 NaN
,并且它会累积到 jnp.where
的伴随项中。编写此类函数的正确方法是确保在部分定义的函数 内部 有一个 jnp.where
,以确保伴随项始终是有限的:
def safe_for_grad_log(x):
return jnp.log(jnp.where(x > 0., x, 1.))
safe_for_grad_log(0.) ==> 0. # Ok
jax.grad(safe_for_grad_log)(0.) ==> 0. # Ok
除了原始的 jnp.where
之外,可能还需要内部的 jnp.where
,例如:
def my_log_or_y(x, y):
"""Return log(x) if x > 0 or y"""
return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), y)
附加阅读:
为什么基于排序顺序的函数的梯度为零?#
如果你定义了一个函数,该函数使用依赖于输入相对顺序的操作(例如 max
、greater
、argsort
等)来处理输入,那么你可能会惊讶地发现梯度在所有地方都是零。这里有一个例子,我们定义 f(x) 为一个阶跃函数,当 x 为负时返回 0,当 x 为正时返回 1:
import jax
import numpy as np
import jax.numpy as jnp
def f(x):
return (x > 0).astype(float)
df = jax.vmap(jax.grad(f))
x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])
print(f"f(x) = {f(x)}")
# f(x) = [0. 0. 0. 1. 1.]
print(f"df(x) = {df(x)}")
# df(x) = [0. 0. 0. 0. 0.]
梯度在各处都为零这一事实乍一看可能会令人困惑:毕竟,输出确实会根据输入的变化而变化,那么梯度怎么会是零呢?然而,在这种情况下,零实际上是正确的结果。
为什么是这样?记住,微分测量的是在 x
发生极小变化时 f
的变化。对于 x=1.0
,f
返回 1.0
。如果我们扰动 x
使其稍微增大或减小,这不会改变输出,因此根据定义,grad(f)(1.0)
应该是零。这种逻辑对所有大于零的 f
值都成立:极小地扰动输入不会改变输出,所以梯度是零。同样,对于所有小于零的 x
值,输出是零。扰动 x
不会改变这个输出,所以梯度是零。这给我们留下了 x=0
这个棘手的情况。当然,如果你向上扰动 x
,它会改变输出,但这有问题:x
的极小变化在函数值中产生有限的变化,这意味着梯度是未定义的。幸运的是,在这种情况下,我们还有另一种测量梯度的方法:我们向下扰动函数,在这种情况下输出不会改变,因此梯度是零。JAX 和其他自动微分系统倾向于以这种方式处理不连续性:如果正梯度和负梯度不一致,但一个是定义的而另一个不是,我们使用定义的那个。根据这种梯度定义,从数学和数值上来说,这个函数的梯度处处为零。
问题源于我们的函数在 x = 0
处有一个不连续性。这里的 f
本质上是一个 Heaviside Step Function,我们可以使用 Sigmoid Function 作为平滑的替代。当 x 远离零时,Sigmoid 函数近似等于 Heaviside 函数,但在 x = 0
处用一个平滑、可微的曲线替代了不连续性。通过使用 jax.nn.sigmoid()
,我们得到了一个类似的计算,并且具有定义良好的梯度:
def g(x):
return jax.nn.sigmoid(x)
dg = jax.vmap(jax.grad(g))
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
with np.printoptions(suppress=True, precision=2):
print(f"g(x) = {g(x)}")
# g(x) = [0. 0.27 0.5 0.73 1. ]
print(f"dg(x) = {dg(x)}")
# dg(x) = [0. 0.2 0.25 0.2 0. ]
The jax.nn
子模块也有其他常见基于秩函数的平滑版本,例如 jax.nn.softmax()
可以替代 jax.numpy.argmax()
的使用,jax.nn.soft_sign()
可以替代 jax.numpy.sign()
的使用,jax.nn.softplus()
或 jax.nn.squareplus()
可以替代 jax.nn.relu()
的使用,等等。
如何将 JAX Tracer 转换为 NumPy 数组?#
在运行时检查转换后的 JAX 函数时,你会发现数组值被替换为 Tracer
对象:
@jax.jit
def f(x):
print(type(x))
return x
f(jnp.arange(5))
这将打印以下内容:
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
一个常见的问题是如何将这样的跟踪器转换回正常的 NumPy 数组。简而言之,无法将跟踪器转换为 NumPy 数组,因为跟踪器是具有给定形状和数据类型的*每个可能*值的抽象表示,而 numpy 数组是该抽象类的一个具体成员。有关跟踪器如何在 JAX 变换的上下文中工作的更多讨论,请参见 JIT 机制。
将 Tracers 转换回数组的问题通常出现在另一个目标的背景下,与在运行时访问计算中的中间值有关。例如:
如果你希望在运行时打印一个跟踪的值以进行调试,你可以考虑使用
jax.debug.print()
。如果你想在转换后的 JAX 函数中调用非 JAX 代码,你可以考虑使用
jax.pure_callback()
,其示例可在 Pure callback example 中找到。如果你想在运行时输入或输出数组缓冲区(例如,从文件加载数据,或将数组内容记录到磁盘),可以考虑使用
jax.experimental.io_callback()
,其示例可以在 IO 回调示例 中找到。
有关运行时回调及其使用示例的更多信息,请参阅 JAX 中的外部回调。
为什么某些CUDA库无法加载/初始化?#
在解析动态库时,JAX 使用通常的 动态链接器搜索模式。JAX 设置 RPATH
指向 pip 安装的 NVIDIA CUDA 包的 JAX 相对位置,如果已安装则优先使用。如果 ld.so
无法在其通常的搜索路径上找到您的 CUDA 运行时库,那么您必须将这些库的路径显式包含在 LD_LIBRARY_PATH
中。确保您的 CUDA 文件可被发现的简单方法是安装 nvidia-*-cu12
pip 包,这些包包含在标准的 jax[cuda_12]
安装选项中。
偶尔,即使你确保了你的运行时库是可发现的,加载或初始化它们时仍可能出现问题。这类问题的常见原因是运行时CUDA库初始化所需的内存不足。这种情况有时发生是因为JAX会为更快的执行预先分配当前可用设备内存的较大块,偶尔导致剩余的内存不足以用于运行时CUDA库的初始化。
在运行多个 JAX 实例、与执行自身预分配的 TensorFlow 并行运行 JAX,或在 GPU 被其他进程大量使用的系统上运行 JAX 时,这种情况尤其可能发生。如有疑问,请尝试通过减少 XLA_PYTHON_CLIENT_MEM_FRACTION
从默认的 .75
,或设置 XLA_PYTHON_CLIENT_PREALLOCATE=false
来减少预分配,再次运行程序。更多详情,请参阅 JAX GPU 内存分配 页面。