错误#
本页列出了在使用 JAX 时可能会遇到的一些错误,并提供了一些代表性的示例,说明如何修复这些错误。
- class jax.errors.ConcretizationTypeError(tracer, context='')#
当在需要具体值的上下文中使用 JAX Tracer 对象时,会发生此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。在某些情况下,可以通过将问题值标记为静态来轻松修复;在其他情况下,这可能表明您的程序正在进行 JAX 的 JIT 编译模型不直接支持的操作。
示例:
- 期望静态值的地方出现了跟踪值
此错误的一个常见原因是使用了需要静态值的跟踪值。例如:
>>> from functools import partial >>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, axis): ... return x.min(axis)
>>> func(jnp.arange(4), 0) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: axis argument to jnp.min().
这通常可以通过将问题参数标记为静态来解决:
>>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return x.min(axis) >>> func(jnp.arange(4), 0) Array(0, dtype=int32)
- 形状取决于跟踪值
当您的即时编译计算中的形状依赖于跟踪量中的值时,也可能出现此类错误。例如:
>>> @jit ... def func(x): ... return jnp.where(x < 0) >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
这是一个与 JAX 的 JIT 编译模型不兼容的操作示例,该模型要求在编译时已知数组大小。这里返回的数组大小取决于 x 的内容,这样的代码无法进行 JIT 编译。
在许多情况下,可以通过修改函数中使用的逻辑来解决这个问题;例如,以下是具有类似问题的代码:
>>> @jit ... def func(x): ... indices = jnp.where(x > 1) ... return x[indices].sum() >>> func(jnp.arange(4)) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: The error arose in jnp.nonzero.
以下是如何在不创建动态大小的索引数组的情况下表达相同操作的方式:
>>> @jit ... def func(x): ... return jnp.where(x > 1, x, 0).sum() >>> func(jnp.arange(4)) Array(5, dtype=int32)
要了解有关跟踪器与常规值、具体值与抽象值之间更多微妙的区别,您可能需要阅读 faq-不同种类的jax值。
- 参数:
tracer (core.Tracer)
context (str)
- class jax.errors.KeyReuseError(message)#
当一个 PRNG 键以不安全的方式被重复使用时,会发生此错误。仅当 jax_debug_key_reuse 设置为 True 时,才会检查键的重复使用。
以下是一个可能导致此类错误的代码示例:
>>> with jax.debug_key_reuse(True): ... key = jax.random.key(0) ... value = jax.random.uniform(key) ... new_value = jax.random.uniform(key) ... --------------------------------------------------------------------------- KeyReuseError Traceback (most recent call last) ... KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
这种密钥重用是有问题的,因为JAX PRNG是无状态的,密钥必须手动拆分;更多信息请参见 Sharp Bits: 随机数。
- 参数:
message (str)
- class jax.errors.NonConcreteBooleanIndexError(tracer)#
当程序尝试在跟踪索引操作中使用非具体布尔索引时,会发生此错误。在JIT编译下,JAX数组必须具有静态形状(即编译时已知的形状),因此必须谨慎使用布尔掩码。通过布尔掩码实现的一些逻辑在
jax.jit()
函数中根本不可能实现;在其他情况下,逻辑可以通过重新表达为JIT兼容的方式来实现,通常使用where()
的三参数版本。以下是此错误可能出现的一些示例。
- 通过布尔掩码构造数组
这通常发生在尝试在JIT上下文中通过布尔掩码创建数组时。例如:
>>> import jax >>> import jax.numpy as jnp >>> @jax.jit ... def positive_values(x): ... return x[x > 0] >>> positive_values(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
此函数试图仅返回输入数组中的正值;除非将 x 标记为静态,否则无法在编译时确定此返回数组的大小,因此无法在 JIT 编译下执行此类操作。
- 可重表达的布尔逻辑
虽然不直接支持创建动态大小的数组,但在许多情况下,可以通过重新表达计算逻辑来实现与JIT兼容的操作。例如,这里还有另一个因相同原因在JIT下失败的函数:
>>> @jax.jit ... def sum_of_positive(x): ... return x[x > 0].sum() >>> sum_of_positive(jnp.arange(-5, 5)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[10])
然而,在这种情况下,有问题的数组只是一个中间值,我们可以用与JIT兼容的三参数版本的
jax.numpy.where()
来表达相同的逻辑:>>> @jax.jit ... def sum_of_positive(x): ... return jnp.where(x > 0, x, 0).sum() >>> sum_of_positive(jnp.arange(-5, 5)) Array(10, dtype=int32)
这种用三参数
where()
替换布尔掩码的模式是解决这类问题的常见方案。- 对 JAX 数组的布尔索引
这种错误经常出现的另一种情况是使用布尔索引时,例如使用
.at[...].set(...)
。这里有一个简单的例子:>>> @jax.jit ... def manual_clip(x): ... return x.at[x < 0].set(0) >>> manual_clip(jnp.arange(-2, 2)) Traceback (most recent call last): ... NonConcreteBooleanIndexError: Array boolean indices must be concrete: ShapedArray(bool[4])
此函数尝试将小于零的值设置为标量填充值。如上所述,这可以通过根据
where()
重新表达逻辑来解决:>>> @jax.jit ... def manual_clip(x): ... return jnp.where(x < 0, 0, x) >>> manual_clip(jnp.arange(-2, 2)) Array([0, 0, 0, 1], dtype=int32)
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerArrayConversionError(tracer)#
当程序尝试将一个 JAX Tracer 对象转换为标准 NumPy 数组时,会发生此错误(有关 Tracer 的更多信息,请参阅 不同种类的 JAX 值)。它通常发生在以下几种情况之一。
- 在 JAX 变换中使用非 JAX 函数
如果你尝试在 JAX 变换(如
jit()
、grad()
、jax.vmap()
等)中使用非 JAX 库(如numpy
或scipy
),可能会出现此错误。例如:>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x): ... return np.sin(x) >>> func(np.arange(4)) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[4]
在这种情况下,你可以通过使用
jax.numpy.sin()
来代替numpy.sin()
来解决问题:>>> import jax.numpy as jnp >>> @jit ... def func(x): ... return jnp.sin(x) >>> func(jnp.arange(4)) Array([0. , 0.84147096, 0.9092974 , 0.14112 ], dtype=float32)
另请参阅 外部回调 ,了解从转换后的 JAX 代码回调主机端计算的选项。
- 使用追踪器索引 numpy 数组
如果在涉及数组索引的行上出现此错误,可能是因为被索引的数组
x
是标准的 numpy.ndarray,而索引idx
是跟踪的 JAX 数组。例如:>>> x = np.arange(10) >>> @jit ... def func(i): ... return x[i] >>> func(0) Traceback (most recent call last): ... TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[0]
根据上下文,您可以通过将 numpy 数组转换为 JAX 数组来解决此问题:
>>> @jit ... def func(i): ... return jnp.asarray(x)[i] >>> func(0) Array(0, dtype=int32)
或者通过将索引声明为静态参数:
>>> from functools import partial >>> @partial(jit, static_argnums=(0,)) ... def func(i): ... return x[i] >>> func(0) Array(0, dtype=int32)
要了解有关跟踪器与常规值、具体值与抽象值之间更多微妙的区别,您可能需要阅读 faq-不同种类的jax值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerBoolConversionError(tracer)#
当 JAX 中的跟踪值在需要布尔值的上下文中使用时,会发生此错误(有关跟踪值的更多信息,请参阅 不同种类的 JAX 值)。
布尔转换可能是显式的(例如
bool(x)
)或隐式的,通过使用控制流(例如if x > 0
或while x
),使用Python布尔运算符(例如z = x and y
,z = x or y
,z = not x
)或使用它们的函数(例如z = max(x, y)
,z = min(x, y)
等)。在某些情况下,这个问题可以通过将跟踪的值标记为静态来轻松解决;在其他情况下,这可能表明您的程序正在进行JAX的JIT编译模型不直接支持的操作。
示例:
- 在控制流中使用的跟踪值
这种情况经常出现在跟踪值在Python控制流中使用时。例如:
>>> from jax import jit >>> import jax.numpy as jnp >>> @jit ... def func(x, y): ... return x if x.sum() < y.sum() else y >>> func(jnp.ones(4), jnp.zeros(4)) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer [...]
我们可以将输入
x
和y
都标记为静态,但这会违背在此使用jax.jit()
的目的。另一个选择是用三项jax.numpy.where()
重新表达 if 语句:>>> @jit ... def func(x, y): ... return jnp.where(x.sum() < y.sum(), x, y) >>> func(jnp.ones(4), jnp.zeros(4)) Array([0., 0., 0., 0.], dtype=float32)
对于更复杂的控制流,包括循环,请参见 lax-控制流。
- 跟踪值的控制流
此错误的另一个常见原因是,如果你无意中跟踪了一个布尔标志。例如:
>>> @jit ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
这里因为
normalize
标志被追踪,所以它不能在Python控制流中使用。在这种情况下,最好的解决方案可能是将这个值标记为静态:>>> from functools import partial >>> @partial(jit, static_argnames=['normalize']) ... def func(x, normalize=True): ... if normalize: ... return x / x.sum() ... return x >>> func(jnp.arange(5), True) Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float32)
更多关于
static_argnums
的信息,请参阅jax.jit()
的文档。- 使用非 JAX 感知函数
此错误的另一个常见原因是,在 JAX 代码中使用了非 JAX 感知的函数。例如:
>>> @jit ... def func(x): ... return min(x, 0)
>>> func(2) Traceback (most recent call last): ... TracerBoolConversionError: Attempted boolean conversion of JAX Tracer ...
在这种情况下,错误发生是因为 Python 的内置
min
函数与 JAX 转换不兼容。可以通过将其替换为jnp.minimum
来解决这个问题:>>> @jit ... def func(x): ... return jnp.minimum(x, 0)
>>> print(func(2)) 0
要了解有关跟踪器与常规值、具体值与抽象值之间更多微妙的区别,您可能需要阅读 faq-不同种类的jax值。
- 参数:
tracer (core.Tracer)
- class jax.errors.TracerIntegerConversionError(tracer)#
当在期望使用Python整数的环境中使用了JAX Tracer对象时,可能会发生此错误(有关Tracer的更多信息,请参见 不同种类的 JAX 值)。这种情况通常发生在几种情况下。
- 传递一个跟踪器来代替整数
如果你尝试将一个追踪值传递给需要静态整数参数的函数,可能会发生此错误;例如:
>>> from jax import jit >>> import numpy as np >>> @jit ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(4), 0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
当这种情况发生时,解决方案通常是将有问题的参数标记为静态:
>>> from functools import partial >>> @partial(jit, static_argnums=1) ... def func(x, axis): ... return np.split(x, 2, axis) >>> func(np.arange(10), 0) [Array([0, 1, 2, 3, 4], dtype=int32), Array([5, 6, 7, 8, 9], dtype=int32)]
另一种方法是将对一个封装了需要保护的参数的闭包应用转换,可以手动操作如下,或者使用
functools.partial()
:>>> jit(lambda arr: np.split(arr, 2, 0))(np.arange(4)) [Array([0, 1], dtype=int32), Array([2, 3], dtype=int32)]
注意每次调用都会创建一个新的闭包,这会破坏编译缓存机制,这就是为什么首选使用 static_argnums。
- 使用 Tracer 索引列表
如果你尝试用一个追踪的数量来索引一个Python列表,可能会发生这个错误。例如:
>>> import jax.numpy as jnp >>> from jax import jit >>> L = [1, 2, 3] >>> @jit ... def func(i): ... return L[i] >>> func(0) Traceback (most recent call last): ... TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[0]
根据上下文,你通常可以通过将列表转换为 JAX 数组来解决这个问题:
>>> @jit ... def func(i): ... return jnp.array(L)[i] >>> func(0) Array(1, dtype=int32)
或者通过将索引声明为静态参数:
>>> from functools import partial >>> @partial(jit, static_argnums=0) ... def func(i): ... return L[i] >>> func(0) Array(1, dtype=int32, weak_type=True)
要了解有关跟踪器与常规值、具体值与抽象值之间更多微妙的区别,您可能需要阅读 faq-不同种类的jax值。
- 参数:
tracer (core.Tracer)
- class jax.errors.UnexpectedTracerError(msg)#
当你使用一个从函数中泄露出来的 JAX 值时,会发生此错误。泄露值是什么意思?如果你对一个函数
f
使用 JAX 变换,该函数在f
之外的某个作用域中存储了对中间值的引用,那么该值就被认为是泄露的。泄露值是一种副作用。(了解更多关于避免副作用的信息,请阅读 纯函数)JAX 在你随后在另一个操作中使用泄漏的值时检测到泄漏,此时它会引发一个
UnexpectedTracerError
。要修复这个问题,请避免副作用:如果一个函数计算出一个在外部作用域中需要的值,请从转换后的函数中显式返回该值。具体来说,
Tracer
是 JAX 在转换过程中对函数中间值的内部表示,例如在jit()
、pmap()
、vmap()
等内部。在转换外部遇到Tracer
意味着发生了泄漏。- 泄漏值的生命周期
考虑以下转换函数的一个例子,该函数将一个值泄漏到外部作用域:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit # 1 ... def side_effecting(x): ... y = x + 1 # 3 ... outs.append(y) # 4 >>> x = 1 >>> side_effecting(x) # 2 >>> outs[0] + 1 # 5 Traceback (most recent call last): ... UnexpectedTracerError: Encountered an unexpected tracer.
在这个例子中,我们将一个跟踪值从内部转换的作用域泄漏到外部作用域。当泄漏的值被使用时,我们得到一个
UnexpectedTracerError
,而不是在值泄漏时。这个例子还展示了泄漏值的生命周期:
一个函数被转换(在这种情况下,通过
jit()
)转换后的函数被调用(启动函数的抽象跟踪并将
x
转换为Tracer
)中间值
y
被创建(跟踪函数的中间值也是一个Tracer
),稍后将被泄露。值被泄露(附加到外部作用域中的列表,通过侧通道逃逸函数)
使用了泄漏的值,并引发了一个 UnexpectedTracerError。
UnexpectedTracerError 消息试图通过包含每个阶段的信息来指向代码中的这些位置。分别是:
转换后的函数名称 (
side_effecting
) 以及触发跟踪的转换jit()
)。一个重构的堆栈跟踪,显示了泄漏的 Tracer 是在哪里创建的,其中包括转换后的函数被调用的地方。(
当 Tracer 被创建时,最后的 5 个堆栈帧是...
)。从重建的堆栈跟踪中,创建了泄漏的 Tracer 的代码行。
泄漏位置未包含在错误消息中,因为很难确定!JAX只能告诉你泄漏值的外观(它的形状以及它是在哪里创建的)以及它泄漏的边界(转换的名称和转换后的函数的名称)。
当前错误的堆栈跟踪指向值被使用的地方。
可以通过从转换函数中返回值来修复错误:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def not_side_effecting(x): ... y = x+1 ... return y >>> x = 1 >>> y = not_side_effecting(x) >>> outs.append(y) >>> outs[0] + 1 # all good! no longer a leaked value. Array(3, dtype=int32, weak_type=True)
- 泄漏检查器
如上文第2点和第3点所述,JAX显示了一个重建的堆栈跟踪,该跟踪指向泄漏值的创建位置。这是因为JAX仅在泄漏值被使用时才引发错误,而不是在值泄漏时。这不是引发此错误的最有用位置,因为您需要知道Tracer泄漏的位置来修复错误。
为了更容易追踪这个位置,你可以使用泄漏检查器。当泄漏检查器启用时,一旦
Tracer
泄漏,就会引发错误。(更准确地说,当泄漏Tracer
的转换函数返回时,它将引发错误)要启用泄漏检查器,您可以使用
JAX_CHECK_TRACER_LEAKS
环境变量或with jax.checking_leaks()
上下文管理器。备注
请注意,此工具是实验性的,可能会报告误报。它通过禁用一些 JAX 缓存来工作,因此会对性能产生负面影响,仅应在调试时使用。
示例用法:
>>> from jax import jit >>> import jax.numpy as jnp >>> outs = [] >>> @jit ... def side_effecting(x): ... y = x+1 ... outs.append(y) >>> x = 1 >>> with jax.checking_leaks(): ... y = side_effecting(x) Traceback (most recent call last): ... Exception: Leaked Trace
- 参数:
msg (str)