自动微分#

在本节中,您将了解自动微分(autodiff)在JAX中的基本应用。JAX拥有一个非常通用的自动微分系统。计算梯度是现代机器学习方法的关键部分,本教程将带您了解一些入门级的自动微分主题,例如:

  • 自动微分-求梯度

  • 自动微分线性逻辑回归

  • 自动微分-嵌套列表-元组和字典

  • 使用 JAX 进行自动微分评估 - value_and_grad

  • 自动微分检查与数值差异

确保也查看 高级自动微分 教程,了解更多高级主题。

虽然理解自动微分在“幕后”如何工作对于在大多数情况下使用 JAX 来说并不是至关重要的,但你仍被鼓励查看这个相当易懂的 视频 以更深入地了解正在发生的事情。

使用 jax.grad 计算梯度#

在 JAX 中,你可以使用 jax.grad() 变换来求标量值函数的导数:

import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816

jax.grad() 接受一个函数并返回一个函数。如果你有一个评估数学函数 \(f\) 的 Python 函数 f,那么 jax.grad(f) 是一个评估数学函数 \( abla f\) 的 Python 函数。这意味着 grad(f)(x) 表示 \( abla f(x)\) 的值。

由于 jax.grad() 作用于函数,你可以将其应用于其自身的输出,以进行任意多次的微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405

JAX 的自动微分使得计算高阶导数变得容易,因为计算导数的函数本身也是可微的。因此,高阶导数就像堆叠变换一样简单。这可以在单变量情况下说明:

函数 \(f(x) = x^3 + 2x^2 - 3x + 1\) 的导数可以计算为:

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

函数 \(f\) 的高阶导数为:

\[::\]

在JAX中计算这些就像链接 jax.grad() 函数一样简单:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

\(x=1\) 处评估上述表达式将得到:

\[::\]

使用 JAX:

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
4.0
10.0
6.0
0.0

2. Computing gradients in a linear logistic regression#

下一个示例展示了如何在逻辑回归模型中使用 jax.grad() 计算梯度。首先,进行设置:

key = jax.random.key(0)

def sigmoid(x):
  return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
  return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
  preds = predict(W, b, inputs)
  label_probs = preds * targets + (1 - preds) * (1 - targets)
  return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
W = jax.random.normal(W_key, (3,))
b = jax.random.normal(b_key, ())

使用带有 argnums 参数的 jax.grad() 函数来对函数关于位置参数进行微分。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)
W_grad=Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32)
b_grad=Array(-0.29227245, dtype=float32)

The jax.grad() API 与 Spivak 的经典著作《流形上的微积分》(1965) 中的优秀符号有直接对应关系,该符号也被 Sussman 和 Wisdom 在其《经典力学的结构与解释》(2015) 和《函数微分几何》(2013) 中使用。这两本书都是开放获取的。特别参见《函数微分几何》的“序言”部分,以了解对该符号的辩护。

本质上,当使用 argnums 参数时,如果 f 是一个用于评估数学函数 \(f\) 的 Python 函数,那么 Python 表达式 jax.grad(f, i) 将评估为一个用于评估 \(\partial_i f\) 的 Python 函数。

3. Differentiating with respect to nested lists, tuples, and dicts#

由于 JAX 的 PyTree 抽象(参见 使用 pytrees),对标准 Python 容器的微分操作可以正常工作,因此你可以随意使用元组、列表和字典(以及任意嵌套)。

继续前面的例子:

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}

你可以创建 自定义 pytree 节点 来不仅与 jax.grad() 一起工作,还可以与其他 JAX 变换(如 jax.jit()jax.vmap() 等)一起工作。

4. Evaluating a function and its gradient using jax.value_and_grad#

另一个方便的功能是 jax.value_and_grad(),它可以在一次传递中高效地计算函数值及其梯度值。

继续前面的例子:

loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519388
loss value 3.0519388

5. Checking against numerical differences#

关于导数的一个优点是,它们可以通过有限差分直接检查。

继续前面的例子:

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117

JAX 提供了一个简单的便利函数,它基本上做同样的事情,但可以检查到你喜欢的任何阶的微分:

from jax.test_util import check_grads

check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

下一步#

本文件中涉及的概念在 JAX 后端的具体实现,在 高级自动微分 教程中有更高级和详细的解释。一些功能,如 高级自动微分-自定义导数规则,依赖于对高级自动微分的理解,因此如果你感兴趣,请查看 高级自动微分 教程中的相关部分。