自动微分食谱#

在 Colab 中打开 在 Kaggle 中打开

JAX 具有相当通用的自动微分系统。在本笔记本中,我们将介绍一系列有趣的自动微分思路,供您根据自己的工作进行选择,从基础知识开始。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0)

梯度#

grad 开始#

您可以使用 grad 对函数进行求导:

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

grad 取一个函数并返回一个函数。如果你有一个 Python 函数 f,它评估数学函数 \(f\),那么 grad(f) 是一个 Python 函数,用于评估数学函数 \(\nabla f\)。这意味着 grad(f)(x) 表示值 \(\nabla f(x)\)

由于 grad 在函数上运作,你可以将它应用于它自己的输出,以便进行任意次数的微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

让我们来看看在一个线性逻辑回归模型中使用grad计算梯度。首先,进行设置:

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# 输出标签为真的概率。
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# 构建一个玩具数据集。
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# 训练损失是训练样本的负对数似然。
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# 初始化随机模型系数
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

使用grad函数及其argnums参数对位置参数进行函数求导。

# 对第一个位置参数求 `loss` 的导数:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# 由于argnums=0是默认值,因此这会执行相同操作:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# 但我们也可以选择不同的价值观,并舍弃这个关键词:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# 包含元组值
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245

这个 grad API 与斯皮瓦克经典著作《流形微积分》(1965年)中的优秀符号直接对应,该符号也用于萨斯曼和威兹登的《经典力学的结构与解释》(2015年)和他们的《泛函微分几何》(2013年)。这两本书都是开放获取的。特别参见《泛函微分几何》的“前言”部分,以了解对这种符号的辩护。

本质上,当使用 argnums 参数时,如果 f 是一个用于评估数学函数 \(f\) 的 Python 函数,则 Python 表达式 grad(f, i) 评估为一个用于评估 \(\partial_i f\) 的 Python 函数。

针对嵌套列表、元组和字典的求导#

相对于标准的Python容器进行微分非常简单,因此可以随意使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}

您可以 注册您自己的容器类型,以便与不仅仅是 grad,而是所有 JAX 转换(jitvmap 等)一起使用。

使用 value_and_grad 评估一个函数及其梯度#

另一个方便的函数是 value_and_grad,用于高效地计算函数的值以及其梯度的值:

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519388
loss value 3.0519388

检查数值差异#

导数的一个优点是它们可以通过有限差分简单地检查:

# 设置有限差分计算的步长
eps = 1e-4

# 检查 b_grad 与标量有限差分
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# 在随机方向上使用有限差分法检查W_grad
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117

JAX提供了一个简单的方便函数,它基本上执行相同的操作,但可以检查您喜欢的任意阶数的微分:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # 检查至二阶导数

使用gradgrad计算Hessian-向量积#

我们可以使用高阶grad来构建Hessian-向量积函数。(稍后我们将编写一个更高效的实现,该实现同时混合了前向模式和反向模式,但这个实现将只使用纯反向模式。)

Hessian-向量积函数在截断牛顿共轭梯度算法中对于最小化平滑凸函数非常有用,或者用于研究神经网络训练目标的曲率(例如,1234)。

对于一个标量值函数 \(f : \mathbb{R}^n \to \mathbb{R}\),其具有连续二阶导数(使得Hessian矩阵是对称的),在点 \(x \in \mathbb{R}^n\) 处的Hessian写作 \(\partial^2 f(x)\)。那么Hessian-向量积函数能够计算

\(\qquad v \mapsto \partial^2 f(x) \cdot v\)

对于任何 \(v \in \mathbb{R}^n\)

这里的关键是不要实例化完整的Hessian矩阵:如果 \(n\) 很大,可能在数百万或数十亿的上下文中与神经网络相关,那么存储它可能是不可能的。

幸运的是,grad已经为我们提供了一种编写高效Hessian-向量积函数的方法。我们只需使用等式

\(\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)\)

其中 \(g(x) = \partial f(x) \cdot v\) 是一个新的标量值函数,它将点乘 \(f\)\(x\) 处的梯度与向量 \(v\)。注意,我们始终只区分向量值参数的标量值函数,这正是我们知道grad是高效的地方。

在JAX代码中,我们可以这样写:

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这个例子表明,你可以自由使用词法闭包,而JAX将永远不会受到干扰或困惑。

我们将在接下来的几个单元中检查这个实现,一旦我们了解如何计算稠密的Hessian矩阵。我们还将编写一个更好的版本,它同时使用正向模式和反向模式。

使用 jacfwdjacrev 求雅可比和赫essian矩阵#

您可以使用 jacfwdjacrev 函数计算完整的雅可比矩阵:

from jax import jacfwd, jacrev

# 将函数从权重矩阵分离到预测结果
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]

这两个函数计算相同的值(直到机器数值精度),但在实现上有所不同:jacfwd 使用正向模式自动微分,对于“高“雅可比矩阵(输出大于输入)更有效,而 jacrev 使用反向模式,对于“宽”雅可比矩阵(输入大于输出)更有效。对于接近方形的矩阵,jacfwd 可能比 jacrev 更具优势。

您也可以使用 jacfwdjacrev 处理容器类型:

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]

有关正向模式和反向模式的更多细节,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用这两个函数的组合,我们可以计算密集的Hessian矩阵:

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]

这个形状是有道理的:如果我们从一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\) 开始,那么在一个点 \(x \in \mathbb{R}^n\),我们期望得到以下形状:

  • \(f(x) \in \mathbb{R}^m\),函数 \(f\)\(x\) 处的值,

  • \(\partial f(x) \in \mathbb{R}^{m \times n}\),在 \(x\) 处的雅可比矩阵,

  • \(\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}\),在 \(x\) 处的海森矩阵,

等等。

为了实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)),或者两者的任何组合。但是前向-反向通常是最有效的。这是因为在内层雅可比计算中,我们通常是在对一个宽雅可比(可能像损失函数 \(f : \mathbb{R}^n \to \mathbb{R}\))进行求导,而在外层雅可比计算中,我们对一个具有方形雅可比的函数进行求导(因为 \(\nabla f : \mathbb{R}^n \to \mathbb{R}^n\)),这就是前向模式胜出的地方。

这是如何制作的:两个基础自自动微分函数#

雅可比向量积 (JVPs,前向模式自动微分)#

JAX 包含对前向模式和反向模式自动微分的高效和通用实现。熟悉的 grad 函数是基于反向模式的,但为了说明这两种模式之间的区别,以及何时每种模式可以有用,我们需要一些数学背景。

数学中的 JVPs#

在数学上,给定一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\),在输入点 \(x \in \mathbb{R}^n\) 处评估的 \(f\) 的雅可比矩阵,记作 \(\partial f(x)\),通常被视为 \(\mathbb{R}^m \times \mathbb{R}^n\) 中的一个矩阵:

\(\qquad \partial f(x) \in \mathbb{R}^{m \times n}\)

但我们也可以将 \(\partial f(x)\) 视为一个线性映射,它将 \(f\) 在点 \(x\) 的定义域的切空间(这只是 \(\mathbb{R}^n\) 的另一个副本)映射到 \(f\) 在点 \(f(x)\) 的值域的切空间(一个 \(\mathbb{R}^m\) 的副本):

\(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)

这个映射被称为 \(f\)\(x\) 处的 推前映射。雅可比矩阵只是这个线性映射在标准基下的矩阵。

如果我们不固定一个特定的输入点 \(x\),那么我们可以将函数 \(\partial f\) 视为首先接受一个输入点并返回该输入点处的雅可比线性映射:

\(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m\)

特别地,我们可以将其解构,使得给定输入点 \(x \in \mathbb{R}^n\) 和一个切向量 \(v \in \mathbb{R}^n\),我们得到一个在 \(\mathbb{R}^m\) 中的输出切向量。我们将这种从 \((x, v)\) 对到输出切向量的映射称为 雅可比向量积,并将其表示为

\(\qquad (x, v) \mapsto \partial f(x) v\)

JAX 代码中的 JVPs#

在 Python 代码中,JAX 的 jvp 函数模型化了这种转换。给定一个评估 \(f\) 的 Python 函数,JAX 的 jvp 提供了一种评估 \((x, v) \mapsto (f(x), \partial f(x) v)\) 的 Python 函数的方法。

from jax import jvp

# 将函数从权重矩阵分离到预测结果
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# 沿着在 `W` 处求值的 `f` 推进向量 `v`
y, u = jvp(f, (W,), (v,))

类似Haskell的类型签名方面,我们可以写成:

jvp :: (a -> b) -> a -> T a -> (b, T b)

在文字上,jvp 接受三个参数:一个类型为 a -> b 的函数,一个类型为 a 的值,以及一个类型为 T a 的切向量值。它返回一个由一个类型为 b 的值和一个类型为 T b 的输出切向量组成的元组。

jvp转换的函数的评估方式与原始函数类似,但为了每个类型为 a 的原始值,它还会推送类型为 T a 的切线值。对于原始函数会应用的每一个基本数值操作,jvp 转换的函数都会为该基本操作执行一个“JVP 规则”,该规则不仅在原始值上评估基本操作,还在这些原始值上应用该基本操作的 JVP。

这种评估策略对计算复杂性有一些直接的影响:由于我们在计算过程中实时评估 JVP,因此我们不需要为后续存储任何东西,因此内存成本与计算深度无关。此外,jvp 转换的函数的 FLOP 成本大约是仅评估函数成本的 3 倍(例如,评估原始函数 sin(x) 的一个单位工作;线性化的一个单位,例如 cos(x);以及将线性化函数应用于向量的一个单位,例如 cos_x * v)。换句话说,对于固定的原始点 \(x\),我们可以以与评估 \(f\) 相似的边际成本评估 \(v \mapsto \partial f(x) \cdot v\)

这个内存复杂性听起来相当诱人!那么为什么我们在机器学习中很少看到前向模式呢?

为了回答这个问题,首先考虑如何使用 JVP 来构建完整的雅可比矩阵。如果我们将 JVP 应用于一个独热向量,它会展示雅可比矩阵的一列, correspondeing到我们输入的非零条目。因此,我们可以一次构建完整的雅可比矩阵,每一列的构建成本大约与一次函数评估相同。对于具有“高”雅可比矩阵的函数,这种方法是高效的,但对于“宽”雅可比矩阵则效率不高。

如果你在机器学习中进行基于梯度的优化,你可能想要最小化一个从参数在 \(\mathbb{R}^n\) 中到标量损失值在 \(\mathbb{R}\) 中的损失函数。这意味着该函数的雅可比矩阵是一个非常宽的矩阵:\(\partial f(x) \in \mathbb{R}^{1 \times n}\),我们通常将其识别为梯度向量 \(\nabla f(x) \in \mathbb{R}^n\)。逐列构建这个矩阵,每次调用的 FLOPs 大约与评估原始函数的 FLOPs 相似,确实看起来效率不高!特别是在训练神经网络时,\(f\) 是一个训练损失函数而 \(n\) 可能处于数百万或数十亿之间,这种方法根本无法扩展。

为了解决此类函数的问题,我们只需要使用反向模式。

向量-雅可比乘积 (VJPs,亦即反向模式自动微分)#

前向模式为我们提供了一个用于评估雅可比-向量乘积的函数,然后我们可以用它来逐列构建雅可比矩阵,而反向模式则是一种返回评估向量-雅可比乘积的函数(等同于雅可比-转置-向量乘积)的方法,我们可以用它来逐行构建雅可比矩阵。

数学中的VJPs#

我们再考虑一个函数 \(f : \mathbb{R}^n \to \mathbb{R}^m\)。 从我们针对JVP的符号开始,VJP的符号是相当简单的:

\(\qquad (x, v) \mapsto v \partial f(x)\)

其中 \(v\)\(f\) 在点 \(x\) 的余切空间中的一个元素(同构于另一个副本的 \(\mathbb{R}^m\))。在严谨的情况下,我们应该把 \(v\) 看作一个线性映射 \(v : \mathbb{R}^m \to \mathbb{R}\),而当我们写 \(v \partial f(x)\) 时,我们的意思是函数复合 \(v \circ \partial f(x)\),由于类型匹配,因为 \(\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m\)。但在常见的情况下,我们可以将 \(v\) 视为 \(\mathbb{R}^m\) 中的一个向量,二者几乎可以互换使用,就像我们有时在“列向量”和“行向量”之间切换一样,而不会多加赘述。

通过这种识别,我们可以将VJP的线性部分看作JVP的线性部分的转置(或伴随共轭):

\(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v\)

对于给定的点 \(x\),我们可以写出签名为

\(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n\)

相应的余切空间上的映射通常被称为 \(f\)\(x\) 处的 拉回。 对于我们的目的,关键在于它从看起来像 \(f\) 的输出的东西,变换到看起来像 \(f\) 的输入的东西,就像我们从转置的线性函数中可能期待的那样。

JAX代码中的VJPs#

从数学回到Python,JAX函数 vjp 可以接受一个用于评估 \(f\) 的Python函数,并返回一个用于评估VJP \((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\) 的Python函数。

from jax import vjp

# 将函数从权重矩阵分离到预测结果
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# 沿 `f` 在 `W` 处求值,将协向量 `u` 拉回
v = vjp_fun(u)

类Haskell类型签名方面,我们可以写成

vjp :: (a -> b) -> a -> (b, CT b -> CT a)

这里我们使用CT a来表示类型a的余切空间。换句话说,vjp接收一个类型为a -> b的函数和一个类型为a的点,并返回一个包含类型为b的值和类型为CT b -> CT a的线性映射的元组。

这很好,因为它让我们可以逐行构建雅可比矩阵,对于计算\((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))\)的FLOP成本大约仅是计算\(f\)成本的三倍。特别是,如果我们想要一个函数\(f : \mathbb{R}^n \to \mathbb{R}\)的梯度,我们只需进行一次调用。这就是为什么grad在基于梯度的优化中高效,即便对于像神经网络训练损失函数这样涉及数百万或数十亿参数的目标。

不过,这也有一个代价:尽管FLOP友好,内存的使用量与计算的深度成正比。此外,反向模式的实现通常比前向模式复杂,尽管JAX有一些窍门(这是未来笔记本的故事!)。

有关反向模式如何工作的更多信息,请参见2017年深度学习夏季学校的这个教程视频

向量值梯度与VJP#

如果你有兴趣进行向量值梯度计算(比如使用 tf.gradients):

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6. 6.]
 [6. 6.]]

赫essian-向量积的正向模式和反向模式使用#

在前面的章节中,我们仅使用反向模式实现了Hessian-向量积函数(假设存在连续的二阶导数):

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

这很高效,但我们可以通过结合正向模式和反向模式做得更好,并节省一些内存。

从数学上讲,给定一个需要求导的函数 \(f : \mathbb{R}^n \to \mathbb{R}\),一个点 \(x \in \mathbb{R}^n\) 用来线性化该函数,以及一个向量 \(v \in \mathbb{R}^n\),我们想要的 Hessian-向量积函数是

\((x, v) \mapsto \partial^2 f(x) v\)

考虑帮助函数 \(g : \mathbb{R}^n \to \mathbb{R}^n\),它被定义为 \(f\) 的导数(或梯度),即 \(g(x) = \partial f(x)\)。我们所需要的只是它的 JVP,因为它将给我们

\((x, v) \mapsto \partial g(x) v = \partial^2 f(x) v\)

我们可以几乎直接将其翻译成代码:

from jax import jvp, grad

# 前进-倒车
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

更好的是,由于我们不需要直接调用 jnp.dot,这个 hvp 函数可以处理任意形状的数组和任意容器类型(例如以嵌套列表/字典/元组存储的向量),甚至不依赖于 jax.numpy

以下是一个使用示例:

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True

您可能会考虑另一种写法,即使用反向覆盖前向:

# 反向-正向
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals)

这并不是那么好,因为正向模式的开销低于反向模式,而这里的外部求导算子必须对比内部更大的计算进行求导,因此将正向模式保持在外部是最有效的。

# 反向-反向,仅适用于单个参数
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
1.31 ms ± 111 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
The slowest run took 5.08 times longer than the fastest. This could mean that an intermediate result is being cached.
3.67 ms ± 2.95 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
The slowest run took 5.17 times longer than the fastest. This could mean that an intermediate result is being cached.
5.46 ms ± 4.47 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
7.52 ms ± 231 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

组合 VJPs、JVPs 和 vmap#

雅可比矩阵和矩阵-雅可比乘积#

现在我们有了jvpvjp变换,可以逐个推送或拉回单个向量,我们可以使用JAX的vmap 变换 一次性推送和拉回整个基。特别是,我们可以利用这一点来编写快速的矩阵-雅可比和雅可比-矩阵乘积。

# 将函数从权重矩阵分离到预测结果
f = lambda W: predict(W, b, inputs)

# 将协向量 `m_i` 沿 `f` 回拉,并在 `W` 处求值,对所有 `i` 进行此操作。
# 首先,使用列表推导式遍历矩阵 M 中的每一行。
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# 现在,使用 vmap 构建一个计算,该计算执行单次快速矩阵-矩阵
# 乘法,而不是对向量-矩阵乘法的外层循环。
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
57.6 ms ± 1.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
1.58 ms ± 104 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22781/1724443100.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
    # jvp 会立即将原值和切线值作为元组返回,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
71.4 ms ± 11.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
822 μs ± 28 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

jacfwdjacrev 的实现#

现在我们已经看到了快速雅可比矩阵和矩阵雅可比乘积,推测如何编写 jacfwdjacrev 也就不难了。我们只需使用相同的技术一次性推导或回拉整个标准基(与单位矩阵同构)。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # 使用vmap进行矩阵-雅可比积运算。
        # 在这里,矩阵是以欧几里得基底表示的,因此我们得到了所有。
        # 一次性输入雅可比矩阵的条目。
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

有趣的是,Autograd 无法做到这一点。我们在 Autograd 中的反向模式 jacobian 实现不得不通过一个外循环的 map 一次拉回一个向量。一次性将一个向量推送通过计算的效率远低于使用 vmap 将所有向量批量处理。

另一个Autograd无法做到的是jit。有趣的是,无论你在需要微分的函数中使用多少Python动态特性,我们总是可以在计算的线性部分使用jit。例如:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)

复数与微分#

JAX在复数和微分方面表现出色。为了支持全纯和非全纯微分,思考JVP和VJP的概念是很有帮助的。

考虑一个复到复的函数\(f: \mathbb{C} \to \mathbb{C}\),并将其与一个相应的函数\(g: \mathbb{R}^2 \to \mathbb{R}^2\)关联起来,

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y))

也就是说,我们已经将 \(f(z) = u(x, y) + v(x, y) i\) 进行分解,其中 \(z = x + y i\),并将 \(\mathbb{C}\)\(\mathbb{R}^2\) 进行等同以得到 \(g\)

由于 \(g\) 仅涉及实数输入和输出,我们已经知道如何为其编写雅可比-向量乘积,比如给定一个切向量 \((c, d) \in \mathbb{R}^2\),即

\(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\)

为了获得对于切向量 \(c + di \in \mathbb{C}\) 应用原始函数 \(f\) 的 JVP,我们只需使用相同的定义,并将结果识别为另一个复数,

\(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}\)

这就是我们对 \(\mathbb{C} \to \mathbb{C}\) 函数的 JVP 定义!请注意,\(f\) 是否全纯并不重要:JVP 是明确的。

请提供要翻译的ipynb文件中的markdown内容。

def check(seed):
  key = random.key(seed)

  # u和v的随机系数
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # 原始点
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # 切向量
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # 检查雅可比向量积
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True

关于VJP如何?我们做的非常相似:对于一个余切向量 \(c + di \in \mathbb{C}\),我们定义函数\(f\)的VJP为

\( (c + di)^* \; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \\ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \\ -i \end{bmatrix} \)

那些负号是为了处理复共轭,以及我们正在使用余向量的事实。

这里是VJP规则的检查:

def check(seed):
  key = random.key(seed)

  # u和v的随机系数
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # 原始点
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # 余切向量
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # 用于数据类型控制

  # 检查垂直跳跃表现
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)

关于方便的包装函数,如 gradjacfwdjacrev

对于 \(\mathbb{R} \to \mathbb{R}\) 的函数,回顾一下我们定义 grad(f)(x)vjp(f, x)[1](1.0),这是可行的,因为将 VJP 应用到 1.0 值上可以揭示梯度(即雅可比,或导数)。我们可以对 \(\mathbb{C} \to \mathbb{R}\) 函数做同样的事情:我们仍然可以使用 1.0 作为共切向量,并且我们得到一个复杂数结果,总结了完整的雅可比:

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)

对于一般的 \(\mathbb{C} \to \mathbb{C}\) 函数,雅可比矩阵有4个实值自由度(如上面的2x2雅可比矩阵所示),因此我们无法希望在一个复数中表示它们。但是对于全纯函数,我们可以做到!全纯函数恰恰是具有特殊性质的 \(\mathbb{C} \to \mathbb{C}\) 函数,其导数可以表示为一个单一的复数。(柯西-黎曼方程 确保上述2x2雅可比矩阵具有复平面中缩放和旋转矩阵的特殊形式,即一个复数在乘法下的作用。)我们可以通过一次对 vjp 的调用以及一个 1.0 的共变量来揭示这个复杂的数字。

因为这仅对全纯函数有效,所以为了使用这个技巧,我们需要向JAX保证我们的函数是全纯的;否则,当对复杂输出函数使用 grad 时,JAX将引发错误:

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

所有的 holomorphic=True 选项所做的只是禁用当输出为复值时的错误提示。我们仍然可以在函数不是全纯的情况下写 holomorphic=True,但是我们得到的答案并不会表示完整的雅可比矩阵。相反,它将是我们仅仅丢弃输出的虚部的函数的雅可比矩阵:

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f 实际上并不是全纯的!
Array(1.-0.j, dtype=complex64, weak_type=True)

有一些关于grad如何工作的有用结论:

  1. 我们可以在全纯的 \(\mathbb{C} \to \mathbb{C}\) 函数上使用 grad

  2. 我们可以使用 grad 来优化 \(f : \mathbb{C} \to \mathbb{R}\) 函数,例如复数参数 x 的实值损失函数,通过朝着 grad(f)(x) 的共轭方向迈步。

  3. 如果我们有一个 \(\mathbb{R} \to \mathbb{R}\) 函数,恰好在内部使用了一些复杂值操作(其中一些必须是非全纯的,例如用于卷积的 FFT),那么 grad 仍然有效,并且我们会得到与仅使用实数值的实现相同的结果。

无论如何,JVP 和 VJP 始终是明确的。如果我们想计算一个非全纯的 \(\mathbb{C} \to \mathbb{C}\) 函数的完整雅可比矩阵,我们可以使用 JVP 或 VJP 来实现!

您应该期望复杂数在JAX中可以随处使用。以下是对复杂矩阵的Cholesky分解进行微分的示例:

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A)
Array([[-0.75341946 +0.j       , -3.0509028 -10.940545j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940545j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64)

更高级的自动微分#

在这个笔记本中,我们逐步探讨了一些简单的,然后是逐渐更复杂的,JAX中的自动微分应用。我们希望你现在觉得在JAX中求导既简单又强大。

在这里还有很多其他的自动微分技巧和功能。这些主题我们没有涵盖,但希望在“高级自动微分食谱”中介绍,包括:

  • 高斯-牛顿向量积,一次线性化

  • 自定义VJP和JVP

  • 在固定点处的高效导数

  • 使用随机Hessian-向量积估计Hessian的迹

  • 仅使用反向模式自动微分进行前向模式自动微分

  • 针对自定义数据类型的导数

  • 检查点(用于高效反向模式的二项式检查点,而不是模型快照)

  • 使用雅可比预累积优化VJP