🔪 JAX - 锐利的部分 🔪#

在 Colab 中打开 在 Kaggle 中打开

当你在意大利乡村漫步时,人们会毫不犹豫地告诉你 JAX 拥有 “纯粹的函数式编程灵魂”

JAX 是一种用于 表达组合 数值程序 变换 的语言。JAX 还能够为 CPU 或加速器(GPU/TPU) 编译 数值程序。 JAX 对于许多数值和科学程序非常有效,但 只有在遵循某些约束的情况下 我们在下面将描述这些约束。

import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

🔪 纯函数#

JAX 转换和编译旨在仅对功能纯粹的 Python 函数进行工作:所有输入数据通过函数参数传递,所有结果通过函数返回输出。如果使用相同的输入调用,纯函数将始终返回相同的结果。

以下是一些非功能纯粹的函数示例,其中 JAX 的行为与 Python 解释器不同。请注意,这些行为并不是由 JAX 系统保证的;使用 JAX 的正确方法是仅对功能纯粹的 Python 函数进行使用。

def impure_print_side_effect(x):
  print("Executing function")  # 这是副作用
  return x

# 副作用在第一次运行时出现
print ("First call: ", jit(impure_print_side_effect)(4.))

# 后续使用相同类型和形状参数的运行可能不会显示副作用。
# 这是因为JAX现在调用了函数的缓存编译版本。
print ("Second call: ", jit(impure_print_side_effect)(5.))

# 当参数的类型或形状发生变化时,JAX 会重新运行该 Python 函数。
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))
Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]
g = 0.
def impure_uses_globals(x):
  return x + g

# JAX 在第一次运行时捕获全局变量的值
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # 更新全局

# 后续运行可能会静默使用全局变量的缓存值。
print ("Second call: ", jit(impure_uses_globals)(5.))

# 当参数的类型或形状发生变化时,JAX 会重新运行该 Python 函数。
# 这将最终读取全局变量的最新值。
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call:  4.0
Second call:  5.0
Third call, different type:  [14.]
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

# JAX 会在参数为特殊跟踪值的情况下,对转换后的函数执行一次运行。
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # 保存的全局变量具有一个内部的JAX值
First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

一个Python函数即使在内部实际使用了有状态的对象,只要它不读取或写入外部状态,也可以是功能上纯的:

def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))
50.0

不建议在任何想要使用 jit 的 JAX 函数中或在任何控制流原语中使用迭代器。原因是迭代器是一个 Python 对象,它引入了状态来获取下一个元素。因此,它与 JAX 的函数式编程模型不兼容。在下面的代码中,提供了一些不正确尝试在 JAX 中使用迭代器的示例。大多数情况下,它们会返回错误,但有些则会产生意外的结果。

import jax.numpy as jnp
from jax import make_jaxpr

# 松弛.fori_循环
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # 预期结果 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # 意外结果 0

# 松弛扫描
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) 抛出错误

# 松弛条件
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# 松弛条件(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) 抛出错误
45
0

🔪 就地更新#

在 Numpy 中,你习惯于这样做:

numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# 就地变异更新
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

如果我们尝试就地更新一个JAX设备数组,我们会遇到一个__错误__!(☉_☉)

%xmode Minimal
Exception reporting mode: Minimal
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# JAX数组的就地更新将导致错误!
jax_array[1, :] = 1.0
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

允许变量原地变异会使程序分析和转换变得困难。JAX要求程序是纯函数。

相反,JAX提供了一种_函数式_数组更新方法,使用JAX数组上的.at属性

️⚠️ 在 jit 编译的代码以及 lax.while_looplax.fori_loop 中,__切片__的__大小__不能是参数 values 的函数,而只能是参数 shapes 的函数——切片的起始索引没有这样的限制。有关此限制的更多信息,请参见下面的 控制流 部分。

数组更新: x.at[idx].set(y)#

例如,上述更新可以写为:

updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)
updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]

JAX的数组更新函数与其NumPy版本不同,采用的是不在原地操作。也就是说,更新后的数组作为新数组返回,而原始数组不会因更新而被修改。

print("original array unchanged:\n", jax_array)
original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

然而,在__jit__编译的代码中,如果x.at[idx].set(y)的__输入值__x没有被重用,编译器将优化数组更新为_就地_进行。

通过其他操作更新数组#

索引数组的更新不仅仅限于覆盖值。例如,我们可以进行索引加法,如下所示:

print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]

有关索引数组更新的更多细节,请参阅.at 属性的文档

🔪 超出边界的索引#

在Numpy中,当你在数组的边界之外进行索引时,通常会抛出错误,就像这样:

np.arange(10)[11]
IndexError: index 11 is out of bounds for axis 0 with size 10

然而,从在加速器上运行的代码中抛出错误可能是困难或不可能的。因此,JAX 必须为越界索引选择某种非错误行为(类似于无效浮点运算的结果是 NaN)。当索引操作是数组索引更新(例如 index_add 或类似 scatter 的原语)时,越界索引的更新将被跳过;当操作是数组索引检索(例如 NumPy 索引或类似 gather 的原语)时,索引会被限制在数组的边界内,因为__必须__返回某些东西。例如,从这个索引操作中将返回数组的最后一个值:

jnp.arange(10)[11]
Array(9, dtype=int32)

如果您想对越界索引的行为进行更细粒度的控制,可以使用ndarray.at的可选参数;例如:

jnp.arange(10.0).at[11].get()
Array(9., dtype=float32)
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Array(nan, dtype=float32)

请注意,由于这种索引检索的行为,像 jnp.nanargminjnp.nanargmax 这样的函数对于由 NaN 组成的切片返回 -1,而 Numpy 会抛出错误。

还要注意,由于上述两种行为不是彼此的逆操作,反向模式自动微分(将索引更新转换为索引检索,反之亦然)不会保留超出边界索引的语义。因此,将 JAX 中的超出边界索引视为未定义行为 可能是个好主意。

🔪 非数组输入:NumPy与JAX#

NumPy通常乐于接受Python列表或元组作为其API函数的输入:

np.sum([1, 2, 3])
np.int64(6)

JAX在这方面有所不同,通常会返回一个有帮助的错误:

jnp.sum([1, 2, 3])
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

这是一个故意的设计选择,因为将列表或元组传递给被跟踪的函数可能会导致静默的性能下降,而这种下降可能会很难被检测到。

例如,考虑以下允许列表输入的 jnp.sum 的宽松版本:

def permissive_sum(x):
  return jnp.sum(jnp.array(x))

x = list(range(10))
permissive_sum(x)
Array(45, dtype=int32)

输出是我们所期望的,但这隐藏了潜在的性能问题。在 JAX 的追踪和 JIT 编译模型中,Python 列表或元组中的每个元素都被视为一个独立的 JAX 变量,并且被单独处理并推送到设备。这可以在上面的 permissive_sum 函数的 jaxpr 中看到:

make_jaxpr(permissive_sum)(x)
{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    n:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    p:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] g
    r:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h
    s:i32[] = convert_element_type[new_dtype=int32 weak_type=False] i
    t:i32[] = convert_element_type[new_dtype=int32 weak_type=False] j
    u:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
    v:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] l
    w:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] m
    x:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] n
    y:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] o
    z:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] p
    ba:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] q
    bb:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] r
    bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] s
    bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] t
    be:i32[10] = concatenate[dimension=0] u v w x y z ba bb bc bd
    bf:i32[] = reduce_sum[axes=(0,)] be
  in (bf,) }

列表中的每个条目被视为单独的输入,这导致跟踪和编译的开销随着列表大小线性增长。为了防止这样的意外,JAX 避免将列表和元组隐式转换为数组。

如果您想将元组或列表传递给 JAX 函数,可以先将其显式转换为数组:

jnp.sum(jnp.array(x))
Array(45, dtype=int32)

🔪 随机数#

如果所有由于糟糕的 rand() 导致结果存疑的科学论文都从图书馆书架上消失,那么每个书架上都会留下一个和你拳头差不多大的空隙。 - 《数值食谱》

随机数生成器和状态#

你习惯于使用来自numpy和其他库的_有状态_伪随机数生成器(PRNG),它们在内部隐藏了许多细节,为你提供一个随时可用的伪随机源:

print(np.random.random())
print(np.random.random())
print(np.random.random())
0.4679941601783879
0.5704016660558697
0.9030377811467778

在底层,numpy使用了Mersenne Twister伪随机数生成器来支持其伪随机函数。该伪随机数生成器的周期为\(2^{19937}-1\),并且在任何时刻可以通过__624个32位无符号整数__和一个__位置__来描述,后者指示已经使用了多少“熵”。

np.random.seed(0)
rng_state = np.random.get_state()
# 打印(rng_状态)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 还有614个数字...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)

这个伪随机状态向量在需要随机数时,会自动在后台更新,每次“消耗”梅森旋转状态向量中的2个uint32。

_ = np.random.uniform()
rng_state = np.random.get_state()
#打印(rng_状态)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = np.random.uniform()
rng_state = np.random.get_state()
#打印(rng_状态)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = np.random.uniform()
rng_state = np.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], 数据类型为 uint32), 2, 0, 0.0)

魔法伪随机数生成器(PRNG)状态的问题在于,很难推理它是如何在不同的线程、进程和设备中使用和更新的,并且在熵产生和消耗的细节对最终用户隐藏时,出错的可能性是_非常容易_的。

梅森旋转算法(Mersenne Twister)PRNG 也被发现有许多问题,它有一个较大的 2.5kB 状态大小,这导致了问题重重的初始化问题。它在现代 BigCrush 测试中失败,并且速度通常较慢。

JAX 随机数生成器 (PRNG)#

JAX实现了一种_显式_的伪随机数生成器(PRNG),其中熵的生成和消耗通过显式传递和迭代PRNG状态来处理。JAX使用了一种现代的 Threefry计数器基础的PRNG,它是__可分叉__的。也就是说,它的设计允许我们__分叉__PRNG状态,以便用于并行随机生成。

随机状态由我们称之为__密钥__的特殊数组元素描述:

key = random.key(0)
key
Array((), dtype=key<fry>) overlaying:
[0 0]

JAX 的随机函数从伪随机数生成器状态产生伪随机数,但 不会 改变状态!

重复使用相同的状态会导致 悲伤单调,剥夺最终用户的 赋予生命的混沌:

print(random.normal(key, shape=(1,)))
print(key)
# 不不不!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]

相反,我们__分割__伪随机数生成器(PRNG),以便在每次需要新的伪随机数时生成可用的__子密钥__:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[0 0]
\---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319] --> normal [-1.2515389]

我们传播__密钥__并在需要新的随机数时生成新的__子密钥__:

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
    \---SPLIT --> new key    Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
             \--> new subkey Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957] --> normal [-0.58665055]

我们可以同时生成多个 子密钥

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.37533438]
[0.98645043]
[0.14553197]

🔪 控制流#

✔ Python 控制流 + 自动微分 ✔#

如果你只想对你的 python 函数应用 grad,你可以使用常规的 python 控制流结构,毫无问题,就像你在使用 Autograd(或 Pytorch 或 TF Eager)一样。

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

print(grad(f)(2.))  # 好的!
print(grad(f)(4.))  # 好的!
12.0
-4.0

Python 控制流 + JIT#

使用 jit 进行控制流更加复杂,并且默认情况下具有更多限制。

这样做是可行的:

@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))
24

所以这也是:

@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y

print(g(jnp.array([1., 2., 3.])))
6.0

但这并不是,至少在默认情况下:

@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

# 这行不通!
f(2)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22743/4086156896.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

这是怎么回事!?

当我们对一个函数进行jit编译时,通常希望编译一个针对许多不同参数值有效的函数版本,以便可以缓存并重用已编译的代码。这样,我们就不必在每次函数评估时重新编译。

例如,如果我们在数组jnp.array([1., 2., 3.], jnp.float32)上评估一个@jit函数,我们可能希望编译一些代码,以便可以重用这些代码来评估jnp.array([4., 5., 6.], jnp.float32)以节省编译时间。

为了获取适用于许多不同参数值的Python代码,JAX以表示可能输入集合的_抽象值_对其进行跟踪。这里有多个不同的抽象层次,不同的变换使用不同的抽象层次。

默认情况下,jit会在ShapedArray抽象层次上跟踪代码,其中每个抽象值代表具有固定形状和数据类型的所有数组值的集合。例如,如果我们使用抽象值ShapedArray((3,), jnp.float32)进行跟踪,我们就可以得到一个可以在相应的数组集合中对任何具体值重用的函数视图。这意味着我们可以节省编译时间。

但这里有一个权衡:如果我们在一个没有特定具体值的ShapedArray((), jnp.float32)上跟踪Python函数,当我们遇到像if x < 3这样的行时,表达式x < 3会计算为一个抽象ShapedArray((), jnp.bool_),代表集合{True, False}。当Python试图将其强制转换为具体的TrueFalse时,我们会遇到错误:我们不知道该选择哪个分支,因此无法继续跟踪!这个权衡在于,使用更高层次的抽象,我们获得了对Python代码的更一般的视图(从而节省了重新编译),但我们需要对Python代码施加更多约束以完成跟踪。

好消息是,您可以自己控制这个权衡。通过让jit在更细化的抽象值上跟踪,您可以放宽跟踪约束。例如,使用jitstatic_argnums参数,我们可以指定在某些参数的具体值上进行跟踪。下面是这个示例函数的再次展示:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnums=(0,))

print(f(2.))
12.0

这是另一个例子,这次涉及一个循环:

def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

f = jit(f, static_argnums=(1,))

f(jnp.array([2., 3., 4.]), 2)
Array(5., dtype=float32)

实际上,这个循环被静态展开。JAX 还可以在更高层次的抽象上进行追踪,比如 Unshaped,但这目前不是任何变换的默认设置。

️⚠️ 具有参数__值__依赖形状的函数

这些控制流问题以一种更微妙的方式出现:我们希望使用 jit 的数值函数无法根据参数 专门化内部数组的形状(根据参数 形状 专门化是可以的)。作为一个简单的例子,我们来创建一个输出恰好依赖于输入变量 length 的函数。

def example_fun(length, val):
  return jnp.ones((length,)) * val
# 未压缩的运行良好
print(example_fun(5, 4))
[4. 4. 4. 4. 4.]
bad_example_jit = jit(example_fun)
# 这将失败:
bad_example_jit(10, 4)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function example_fun at /var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22743/1934221560.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.
# static_argnums 指示 JAX 在指定位置的参数发生变化时重新编译:
good_example_jit = jit(example_fun, static_argnums=(0,))
# 首先编译
print(good_example_jit(10, 4))
# 重新编译
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]

static_argnums 在我们的示例中,如果 length 很少变化,可能会很方便,但如果变化频繁,就会引发灾难!

最后,如果你的函数有全局副作用,JAX 的追踪器可能会导致奇怪的事情发生。一个常见的陷阱是在 jit 函数内部尝试打印数组:

@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
Array(4, dtype=int32, weak_type=True)

结构化控制流原语#

在JAX中有更多的控制流选项。假设你想避免重新编译,但仍然希望使用可追踪的控制流,并且避免展开大型循环。那么你可以使用这4个结构化控制流原语:

  • lax.cond 可微分

  • lax.while_loop 正向模式可微分

  • lax.fori_loop 一般来说是 正向模式可微分; 如果端点是静态的,则 正向和反向模式可微分

  • lax.scan 可微分

cond#

python等价:

def cond(pred, true_fun, false_fun, operand):
  if pred:
    return true_fun(operand)
  else:
    return false_fun(operand)
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> 数组([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> 数组([-1.], dtype=float32)
Array([-1.], dtype=float32)

jax.lax 提供了两个其他函数,使得可以在动态条件上进行分支:

  • lax.select 类似于 lax.cond 的批处理版本,其选择以预先计算的数组形式表达,而不是函数。

  • lax.switch 类似于 lax.cond,但允许在任意数量的可调用选择之间切换。

此外,jax.numpy 为这些函数提供了多个 numpy 风格的接口:

  • jnp.where 有三个参数,是 lax.select 的 numpy 风格封装。

  • jnp.piecewiselax.switch 的 numpy 风格封装,但它根据布尔条件列表进行切换,而不是单个标量索引。

  • jnp.select 的 API 类似于 jnp.piecewise,但选择是以预先计算的数组形式给出的,而不是函数。它的实现是通过多次调用 lax.select 来完成的。

while_loop#

Python等价:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> 数组(10, dtype=int32)
Array(10, dtype=int32, weak_type=True)

fori_loop#

Python 等价代码:

def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> 数组([45], dtype=int32)
Array(45, dtype=int32, weak_type=True)

摘要#

\[\begin{split} \begin{array} {r|rr} \hline \ \textrm{构造} & \textrm{jit} & \textrm{grad} \\ \hline \ \textrm{if} & ❌ & ✔ \\ \textrm{for} & ✔* & ✔\\ \textrm{while} & ✔* & ✔\\ \textrm{lax.cond} & ✔ & ✔\\ \textrm{lax.while_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.fori_loop} & ✔ & \textrm{fwd}\\ \textrm{lax.scan} & ✔ & ✔\\ \hline \end{array} \end{split}\]

\(\ast\) = 参数--无关循环条件 - 展开循环

🔪 动态形状#

JAX代码在 jax.jitjax.vmapjax.grad 等变换中使用时要求所有输出数组和中间数组具有静态形状:也就是说,形状不能依赖于其他数组中的值。

例如,如果您要实现自己的 jnp.nansum 版本,您可能会从以下内容开始:

def nansum(x):
  mask = ~jnp.isnan(x)  # 布尔掩码选择非NaN值
  x_without_nans = x[mask]
  return x_without_nans.sum()

在不考虑JIT和其他转换的情况下,这按预期工作:

x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))
10.0

如果您尝试对这个函数应用 jax.jit 或其他变换,将会出现错误:

jax.jit(nansum)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

问题在于x_without_nans的大小依赖于x中的值,换句话说,它的大小是动态的。 在JAX中,通常可以通过其他手段绕过对动态大小数组的需求。 例如,在这里可以使用jnp.where的三参数形式将NaN值替换为零,从而在避免动态形状的同时计算出相同的结果:

@jax.jit
def nansum_2(x):
  mask = ~jnp.isnan(x)  # 布尔掩码选择非NaN值
  return jnp.where(mask, x, 0).sum()

print(nansum_2(x))
10.0

在其他存在动态形状数组的情况下,也可以使用类似的技巧。

🔪 NaNs#

调试 NaNs#

如果您想追踪在您的函数或梯度中 NaNs 出现的位置,可以通过以下方式启用 NaN 检查器:

  • 设置环境变量 JAX_DEBUG_NANS=True

  • 在主文件的顶部添加 jax.config.update("jax_debug_nans", True)

  • 在主文件中添加 jax.config.parse_flags_with_absl(),然后使用命令行标志设置选项,例如 --jax_debug_nans=True

这将导致在生成 NaN 时计算立即出错。开启这个选项会在每个由 XLA 生成的浮点类型值上添加 NaN 检查。这意味着每个不在 @jit 之下的原始操作的值都会被拉回主机并作为 ndarrays 进行检查。对于处于 @jit 下的代码,每个 @jit 函数的输出都会被检查,如果存在 NaN,函数将以逐步优化的模式重新运行,从而有效地一次移除一个层级的 @jit

可能会出现一些棘手的情况,比如在 @jit 下出现的 NaN,但在逐步优化模式中没有产生。在这种情况下,您将看到警告信息,但您的代码将继续执行。

如果在梯度评估的反向传播过程中产生了 NaN,当在堆栈跟踪中抛出异常时,您将进入 backward_pass 函数,该函数本质上是一个简单的 jaxpr 解释器,反向遍历原始操作序列。在下面的示例中,我们使用命令行 env JAX_DEBUG_NANS=True ipython 启动了一个 ipython repl,然后运行了这个:

在 [1]: import jax.numpy as jnp

在 [2]: jnp.divide(0., 0.)
---------------------------------------------------------------------------
浮点错误                        回溯(最近的调用最后)
<ipython-input-2-f2e2c413b437> 在 <module>()
----> 1 jnp.divide(0., 0.)

.../jax/jax/numpy/lax_numpy.pyc 在 divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc 在 true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.pyc 在 div(x, y)
    244 def div(x, y):
    245   r"""逐元素除法: :math:`x \over y`."""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

... 堆栈追踪 ...

.../jax/jax/interpreters/xla.pyc 在 handle_result(device_buffer)
    103         py_val = device_buffer.to_py()
    104         if np.any(np.isnan(py_val)):
--> 105           raise FloatingPointError("无效值")
    106         else:
    107           return Array(device_buffer, *result_shape)

浮点错误: 无效值

捕获到生成的nan。通过运行%debug,我们可以获得一个事后调试器。这对于@jit下的函数也有效,如下面的示例所示。

In [4]: from jax import jit

In [5]: @jit
   ...: def f(x, y):
   ...:     a = x * y
   ...:     b = (x + y) / (x - y)
   ...:     c = a + 2
   ...:     return a + b * c
   ...:

In [6]: x = jnp.array([2., 0.])

In [7]: y = jnp.array([3., 0.])

In [8]: f(x, y)
在 jit 函数的输出中遇到无效值。正在调用未优化的版本。
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
<ipython-input-8-811b7ddb3300> in <module>()
----> 1 f(x, y)

 ... stack trace ...

<ipython-input-5-619b39acbaac> in f(x, y)
      2 def f(x, y):
      3     a = x * y
----> 4     b = (x + y) / (x - y)
      5     c = a + 2
      6     return a + b * c

.../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
    343     return floor_divide(x1, x2)
    344   else:
--> 345     return true_divide(x1, x2)
    346
    347

.../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
    332   x1, x2 = _promote_shapes(x1, x2)
    333   return lax.div(lax.convert_element_type(x1, result_dtype),
--> 334                  lax.convert_element_type(x2, result_dtype))
    335
    336

.../jax/jax/lax.py in div(x, y)
    244 def div(x, y):
    245   r"""逐元素除法: :math:`x \over y`。"""
--> 246   return div_p.bind(x, y)
    247
    248 def rem(x, y):

 ... stack trace ...

当这段代码在@jit函数的输出中看到nan时,它会调用未优化的代码,因此我们仍然会获得明确的堆栈跟踪。我们可以使用%debug运行事后调试器,以检查所有值并找出错误。

⚠️ 如果您不是在调试,那么不应该启用NaN检查器,因为这会导致大量设备和主机之间的往返以及性能回退!

⚠️ NaN检查器不适用于pmap。要调试pmap代码中的nan,可以尝试将pmap替换为vmap

🔪 双精度(64位)#

目前,JAX默认强制使用单精度数字,以减轻Numpy API偏向于积极将操作数提升到double的倾向。这对于许多机器学习应用程序来说是期望的行为,但这可能会让你感到意外!

x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype
/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22743/1032186105.py:1: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
dtype('float32')

要使用双精度数字,您需要在__启动时__设置jax_enable_x64配置变量。

有几种方法可以做到这一点:

  1. 您可以通过设置环境变量JAX_ENABLE_X64=True来启用64位模式。

  2. 您可以在启动时手动设置jax_enable_x64配置标志:

    # 再次强调,这仅在启动时有效!
    import jax
    jax.config.update("jax_enable_x64", True)
    
  3. 您可以使用absl.app.run(main)解析命令行标志:

    import jax
    jax.config.config_with_absl()
    
  4. 如果您希望JAX为您运行absl解析,即您不想执行absl.app.run(main),您可以改为使用:

    import jax
    if __name__ == '__main__':
      # 调用 jax.config.config_with_absl() *并且* 运行 absl 解析
      jax.config.parse_flags_with_absl()
    

请注意,#2-#4 适用于 JAX 的 任何 配置选项。

然后我们可以确认x64模式已启用,例如:

import jax
import jax.numpy as jnp
from jax import random

jax.config.update("jax_enable_x64", True)
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype # --> dtype('float64')

注意事项#

⚠️ XLA 并不支持所有后端的 64 位卷积!

🔪 与NumPy的杂项差异#

尽管jax.numpy努力复制NumPy的API行为,但确实存在一些边角案例,其行为有所不同。 许多这样的案例在上面的部分中已详细讨论;在这里我们列出一些其他已知的API差异。

  • 对于二元操作,JAX的类型提升规则与NumPy使用的规则略有不同。有关更多细节,请参见类型提升语义

  • 在执行不安全类型转换时(即目标数据类型无法表示输入值的转换),JAX的行为可能依赖于后端,并且通常可能与NumPy的行为有所不同。NumPy允许在这些场景中通过casting参数控制结果(请参见np.ndarray.astype);而JAX并未提供此类配置,而是直接继承 XLA:ConvertElementType的行为。

    下面是一个不安全转换的示例,其中NumPy和JAX的结果不同:

    >>> np.arange(254.0, 258.0).astype('uint8')
    array([254, 255,   0,   1], dtype=uint8)
    
    >>> jnp.arange(254.0, 258.0).astype('uint8')
    Array([254, 255, 255, 255], dtype=uint8)
    
    

    这种不匹配通常发生在将浮点数转换为整数类型或反之亦然时,尤其是极端值。

完。#

如果这里没有涵盖的内容导致您悲伤和咬牙切齿,请告诉我们,我们会扩展这些介绍性_建议_!