使用 jax.checkpoint(即 jax.remat)控制自动微分的保存值#

import jax
import jax.numpy as jnp

摘要#

使用 jax.checkpoint 装饰器(别名为 jax.remat)与 jax.grad 配合,控制在前向传播中保存哪些中间值,以及在反向传播中重新计算哪些中间值,从而在内存和浮点运算(FLOPs)之间进行权衡。

不要错过 实用说明,讨论 jax.checkpoint 如何与 jax.jit 互动。

在未使用 jax.checkpoint 的情况下,jax.grad(f)(x) 的前向传播会保存雅可比系数和其他中间值的值,以便在反向传播中使用。我们将这些保存的值称为 残差

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# 如果我们对 `jax.grad(f)(W1, W2, W3, x)` 进行评估
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)

通过将 jax.checkpoint 应用到子函数,无论是作为装饰器还是在特定应用点上,我们迫使 JAX 不保存该子函数的任何残余。相反,只有被 jax.checkpoint 装饰的函数的输入可能会被保存,而在反向传播过程中消耗的任何残余会根据这些输入在需要时重新计算。

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)

在这里,两个 sin 应用的值被保存,因为它们是后续应用 jax.checkpoint 装饰的 g 函数的参数,而输入到 jax.checkpoint 装饰的函数中可能会被保存。但 cos 应用的值没有被保存。

为了控制哪些值是可保存的,而不必编辑要微分的函数的定义,可以使用重材料化 策略。以下是一个示例,只有在没有批次维度的情况下才保存 dot 操作的结果(因为它们通常是 FLOP 绑定的,因此值得保存而不是重新计算):

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)

您还可以使用策略引用您通过 jax.ad_checkpoint.checkpoint_name 命名的中间值:

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)

在玩这些玩具示例时,我们可以使用本笔记本中定义的 print_fwd_bwd 工具更仔细地查看发生了什么:

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# 未使用 jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
                                                                                                                                                                      
  forward computation:                                                        backward computation:                                                                   
                                                                                                                                                                      
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[7]. let                                                            
      f:f32[5] = sin e                                                            k:f32[7] = mul j a                                                                  
      g:f32[5] = cos e                                                            l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c                
      h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f        m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b                
      i:f32[6] = sin h                                                            n:f32[6] = mul l d                                                                  
      j:f32[6] = cos h                                                            o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f                
      k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i        p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e                
      l:f32[7] = sin k                                                            q:f32[5] = mul o g                                                                  
      m:f32[7] = cos k                                                            r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i                
    in (l, m, i, c, j, f, b, g, d, a) }                                           s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h                
                                                                                in (s, p, m, r) }                                                                     
# 使用 jax.checkpoint 并设置策略为 jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
                                                                                                                                                                             
  forward computation:                                                        backward computation:                                                                          
                                                                                                                                                                             
  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let                   { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let  
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d        i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[                                        
      f:f32[5] = sin e                                                              differentiated=True                                                                      
      g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f          jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]             
      h:f32[6] = sin g                                                                  s:f32[4] t:f32[7]. let                                                               
      i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h              u:f32[5] = sin m                                                                     
      j:f32[7] = sin i                                                                  v:f32[5] = cos m                                                                     
    in (j, e, g, i, a, b, c, d) }                                                       w:f32[6] = sin n                                                                     
                                                                                        x:f32[6] = cos n                                                                     
                                                                                        y:f32[7] = cos o                                                                     
                                                                                        z:f32[7] = mul t y                                                                   
                                                                                        ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r                
                                                                                        bb:f32[6] = mul ba x                                                                 
                                                                                        bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q               
                                                                                        bd:f32[5] = mul bc v                                                                 
                                                                                        be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p               
                                                                                        bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s               
                                                                                        bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u               
                                                                                        bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w                
                                                                                      in (bf, bg, bh, be) }                                                                  
                                                                                    policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>                               
                                                                                    prevent_cse=True                                                                         
                                                                                  ] a b c d e f g h                                                                          
                                                                                in (i, j, k, l) }                                                                            

让我们一步一步来思考#

你可能想首先(重新)阅读 自动微分食谱 第1部分

jax.checkpoint 的基本知识#

jax.linearizejax.vjp 中,对于某些值的计算方式和时机有灵活性。不同的选择可能会在内存使用和浮点运算(FLOPs)之间进行权衡。JAX 提供了通过 jax.checkpoint 对这些选择的控制。

一个这样的选择是决定在前向传播中尽早计算雅可比系数,还是在反向传播中在系数需要之前进行计算。考虑 sin_vjp 的例子:

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

另一种有效的实现是在反向传播中计算 jnp.cos(x) 的值,而不是在前向传播中计算:

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

对于这个特定的函数,两种版本所使用的内存量是相同的,尽管我们减少了原始计算(即前向传播)的FLOPs,并增加了余切计算(即反向传播)的FLOPs。

在函数组合方面还有另一个选择。回想一下我们关于两个函数组合的VJP规则:

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

另一种选择是:

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

用语言来说,这个替代实现不会在前向传递中计算 g_vjp 或其闭包中的残差值。相反,它只在反向传递 f_bwd2 中计算它们。这意味着 f_vjp_checkpoint 需要更少的内存:如果 gh 对它们的残差各自需要相似数量的内存,且都远大于 x,那么由 f_vjp_checkpoint(x) 产生的函数所需内存是 f_vjp(x) 的一半!

我们所付出的代价是冗余的工作:在 f_bwd2 中,我们必须重新评估 g(x),作为 jax.vjp(g, x) 的一部分,只是为了丢弃它的值(在行 _, g_vjp = jax.vjp(g, x) 的下划线变量中)。

我们可以通过使用 jax.checkpoint 在原始函数 f 的替代定义中获得这种 VJP 行为,而无需直接编写 VJP 函数:

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

换句话说,我们将 jax.checkpoint 应用于 g,即 f 的第一阶段,而不是直接应用于 f。这样,当我们评估 jax.grad(f_checkpoint)(x) 时,我们会得到如下计算:

  1. 运行 g 的前向传递,丢弃残差值;

  2. 运行 h 的前向传递,保存残差;

  3. 运行 h 的反向传递,消耗步骤 2 中的残差;

  4. 重新运行 g 的前向传递,保存残差;

  5. 运行 g 的反向传递,消耗步骤 4 中的残差。

也就是说,通过评估 jax.grad(f_checkpoint)(x),我们会得到与以下计算相同的结果:

def f_checkpoint_grad(x):
  y = g(x)                  # 步骤 1
  _, h_vjp = jax.vjp(h)(y)  # 第二步
  y_bar, = h_vjp(1.0)       # 第三步
  _, g_vjp = jax.vjp(g, x)  # 步骤四
  x_bar, = g_vjp(y_bar)     # 第五步
  return x_bar

一般来说,jax.checkpoint(foo) 是一个新函数,其输入输出行为与 foo 相同,但在自动微分下表现不同,特别是在 jax.linearizejax.vjp(及其包装器,如 jax.grad)下,然而在 jax.jvp 下则表现不同。当被微分时,只有 jax.checkpoint 区分的函数的输入会在正向传递中存储;在反向传递中,将重新计算残差(即 foo 的中间值及其雅可比系数值,这些值在反向传递中是需要的)。

注意,如果 f = lambda x: h(g(x)) 是我们想要求导的函数,也就是说,如果我们想应用 jax.grad(f),那么对 f 本身应用 jax.checkpoint 并不会带来任何内存节省。这是因为评估 jax.grad(jax.checkpoint(f))(x) 会导致以下计算:

  1. 运行前向传播,丢弃所有残差;

  2. 立即重新运行前向传播,保存残差;

  3. 运行后向传播,消耗步骤 2 中的残差。

也就是说,在代码中我们将会有如下内容:

def f_grad_bad(x):
  _ = f(x)                  # 步骤1
  _, f_vjp = jax.vjp(f, x)  # 第二步
  x_bar, = f_vjp(1.0)       # 第三步
  return x_bar

我们也不会通过对 f 的第二个阶段 h 应用 jax.checkpoint 来获得任何内存节省。这是因为评估 jax.grad(lambda x: jax.checkpoint(h)(g(x))) 将导致以下计算:

  1. 执行 g 的前向传播,保存残差;

  2. 执行 h 的前向传播,丢弃残差;

  3. 立即重新执行 h 的前向传播,保存残差;

  4. 执行 h 的反向传播,消耗步骤 3 的残差;

  5. 执行 g 的反向传播,消耗步骤 1 的残差。

也就是说,在代码中我们会有类似如下的内容:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # 步骤 1
  z = h(y)                  # 第二步
  _, h_vjp = jax.vjp(h, y)  # 第三步
  y_bar, = h_vjp(1.0)       # 第三步
  x_bar, = g_vjp(y_bar)     # 第五步
  return x_bar

稍微更一般地说,如果我们有一个函数的链式组合,比如f = lambda x: f3(f2(f1(x))),并且我们想评估jax.grad(f),我们可以说:

  • 我们不应该对整个函数f应用jax.checkpoint,因为那样不会节省任何内存(并且会导致不必要的重新计算);

  • 我们不应该对最后一个子函数f3应用jax.checkpoint,因为那样也不会节省任何内存(并且会导致不必要的重新计算);

  • 我们可以对f1f2或它们的组合lambda x: f2(f1(x))应用jax.checkpoint,因为这些可能会节省内存并表达不同的内存/重新计算权衡。

自定义策略以确定可保存内容#

如前所述,使用 jax.checkpoint 可以从一个极端切换到另一个极端:

  • 没有 jax.checkpoint 时,JAX 的自动微分倾向于在前向传播中计算尽可能多的内容并将其存储以备后向传播使用;

  • 使用 jax.checkpoint 装饰器后,我们反而在前向传播中计算尽可能少的内容,而在后向传播中根据需要重新计算值。

为了在这两个极端之间操作,保存某些内容而不保存其他内容,我们可以在子函数上仔细放置 jax.checkpoint 装饰器。但这需要编辑要进行微分的函数,例如模型代码,这可能会不方便。实验不同的变体也可能很困难。

因此,另一种选择是使用 jax.checkpointpolicy 参数。策略是一个可调用对象(即函数),它以一阶原语应用的类型级规范作为输入,并返回一个布尔值,指示相应的输出值是否被允许保存为残差(或者必须在(共)切线计算中根据需要重新计算)。为了编写稳健的代码,策略应从 jax.checkpoint_policies 上的属性中选择,例如 jax.checkpoint_policies.dots_with_no_batch_dims_saveable,因为编写自定义策略可调用对象的 API 被认为是内部的。

例如,考虑要进行微分的这个函数:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)

在正向传播中,我们可能不需要保存如此多的值,也许我们只想保存没有批次维度的矩阵乘法结果(因为它们可能是受限于FLOP而非内存)。我们可以使用策略 jax.checkpoint_policies.dots_with_no_batch_dims_saveable 来实现这一点:

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)

请注意,通过提供一个策略,我们不需要编辑定义 losspredictlayer 的代码。如果我们想在调用代码(例如训练脚本)中实验各种策略,而不更改库代码(例如神经网络库),这尤其方便。

某些策略可以引用名为 jax.ad_checkpoint.checkpoint_name 的值:

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x

checkpoint_name 本身只是一个恒等函数。但由于一些策略函数知道如何查找它们,我们可以使用这些名称来控制由 checkpoint_name 输出的某些值是否被视为可保存的:

print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'

另一个与名称相关的策略是 jax.checkpoint_policies.save_only_these_names

一些策略包括:

  • everything_saveable(默认策略,就像根本没有使用 jax.checkpoint 一样)

  • nothing_saveable(即重新计算所有内容,就像根本没有使用自定义策略一样)

  • dots_saveable 或其别名 checkpoint_dots

  • dots_with_no_batch_dims_saveable 或其别名 checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names(保存除 checkpoint_name 输出以外的任何值)

  • save_any_names_but_these(仅保存命名值,即 checkpoint_name 的任何输出,但不包括给定名称的输出)

  • save_only_these_names(仅保存命名值,并且仅在给定的名称中)

策略只指明什么是可保存的;只有在向后传播中实际需要时,值才会被保存。

高级: 递归 jax.checkpoint#

通过以正确的方式应用 jax.checkpoint,可以表达出内存使用和(重新)计算之间的多种权衡。一个令人惊讶的例子是_递归_检查点,其中我们将 jax.checkpoint 应用于一个函数,该函数本身以某种方式调用装饰了 jax.checkpoint 的函数,使得 \(D\) 个函数的链式组合的内存使用量按 \(\mathcal{O}(\log_2 D)\) 而不是 \(\mathcal{O}(D)\) 进行扩展。

作为一个玩具示例,考虑多个 jnp.sin 函数的链式组合:

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

通常,存储的残差数量与链的长度呈线性比例关系:

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)

但我们可以递归地应用 jax.checkpoint 来提高扩展性:

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)

这里的成本,像往常一样,是重新计算:特别是,我们最终执行的浮点运算次数大约是 \(\mathcal{O}(\log_2 D)\) 倍:

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                 
  forward computation:                  backward computation:                                                                    
                                                                                                                                 
  { lambda ; a:f32[]. let               { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let  
      b:f32[] = sin a                       j:f32[] = mul i a                                                                    
      c:f32[] = cos a                       k:f32[] = mul j b                                                                    
      d:f32[] = sin b                       l:f32[] = mul k c                                                                    
      e:f32[] = cos b                       m:f32[] = mul l d                                                                    
      f:f32[] = sin d                       n:f32[] = mul m e                                                                    
      g:f32[] = cos d                       o:f32[] = mul n f                                                                    
      h:f32[] = sin f                       p:f32[] = mul o g                                                                    
      i:f32[] = cos f                       q:f32[] = mul p h                                                                    
      j:f32[] = sin h                     in (q,) }                                                                              
      k:f32[] = cos h                                                                                                            
      l:f32[] = sin j                                                                                                            
      m:f32[] = cos j                                                                                                            
      n:f32[] = sin l                                                                                                            
      o:f32[] = cos l                                                                                                            
      p:f32[] = sin n                                                                                                            
      q:f32[] = cos n                                                                                                            
    in (p, q, o, m, k, i, g, e, c) }                                                                                             
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
                                                                                                                                        
  forward computation:                                                              backward computation:                               
                                                                                                                                        
  { lambda ; a:f32[]. let                                                           { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let     
      b:f32[] = remat2[                                                                 e:f32[] = mul d a                               
        differentiated=False                                                            f:f32[] = mul e b                               
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }        g:f32[] = remat2[                               
        policy=None                                                                       differentiated=True                           
        prevent_cse=True                                                                  jaxpr={ lambda ; h:f32[] i:f32[]. let         
      ] a                                                                                     j:f32[] = sin h                           
      f:f32[] = sin b                                                                         k:f32[] = cos h                           
      g:f32[] = sin f                                                                         l:f32[] = cos j                           
      h:f32[] = sin g                                                                         m:f32[] = mul i l                         
      i:f32[] = sin h                                                                         n:f32[] = mul m k                         
      j:f32[] = sin i                                                                       in (n,) }                                   
      k:f32[] = cos i                                                                     policy=None                                   
      l:f32[] = sin j                                                                     prevent_cse=True                              
      m:f32[] = cos j                                                                   ] c f                                           
    in (l, m, k, g, a) }                                                                o:f32[] = remat2[                               
                                                                                          differentiated=True                           
                                                                                          jaxpr={ lambda ; p:f32[] q:f32[]. let         
                                                                                              r:f32[] = sin p                           
                                                                                              s:f32[] = sin r                           
                                                                                              t:f32[] = sin s                           
                                                                                              u:f32[] = cos s                           
                                                                                              v:f32[] = cos t                           
                                                                                              w:f32[] = mul q v                         
                                                                                              x:f32[] = mul w u                         
                                                                                              y:f32[] = remat2[                         
                                                                                                differentiated=True                     
                                                                                                jaxpr={ lambda ; z:f32[] ba:f32[]. let  
                                                                                                    bb:f32[] = sin z                    
                                                                                                    bc:f32[] = cos z                    
                                                                                                    bd:f32[] = cos bb                   
                                                                                                    be:f32[] = mul ba bd                
                                                                                                    bf:f32[] = mul be bc                
                                                                                                  in (bf,) }                            
                                                                                                policy=None                             
                                                                                                prevent_cse=True                        
                                                                                              ] p x                                     
                                                                                            in (y,) }                                   
                                                                                          policy=None                                   
                                                                                          prevent_cse=True                              
                                                                                        ] 3.0 g                                         
                                                                                      in (o,) }                                         

实用笔记#

当微分函数被转换为XLA进行编译时,例如通过对包含jax.grad调用的函数应用jax.jit,XLA会自动优化计算,包括何时计算或重新生成值的决策。因此,对于在jax.jit下的微分函数,通常不需要jax.checkpoint。XLA会为你优化这些内容。

一个例外是使用分阶段控制流,比如jax.lax.scan。在多个控制流原语之间的自动编译器优化,例如在前向传递scan与相应的反向传递scan之间,通常没有那么彻底。因此,在传递给jax.lax.scan的主体函数上使用jax.checkpoint通常是个好主意。

例如,在大型Transformer模型中,一个常见的模式是将架构表示为层的jax.lax.scan,以减少编译时间。也就是说,可以使用简单的全连接网络作为类比,而不是编写类似于以下内容:

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # 层的权重和偏置对
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x

我们将使用 jax.lax.scan 来迭代层的应用:

StackedWeights = jnp.ndarray  # 所有权重矩阵堆叠在一起
StackedBiases = jnp.ndarray   # 所有偏置向量堆叠在一起

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x

这个扫描层版本减少了编译时间,但由于妨碍了一些编译器的优化,可能导致梯度计算效率低下。为了缓解这个问题,我们将对扫描的函数使用 jax.checkpoint

# 该代码段示例了如何使用`jax`库的检查点功能,以部分应用的形式定义一个层。

## `layer` 函数
`layer` 函数接受输入 `x` 和权重偏置元组 `W_b_pair`,并返回经过激活函数处理的输出。

- `x`:输入数据。
- `W_b_pair`:由权重 `W` 和偏置 `b` 组成的元组。

该函数的输出是:
- `out`:经过线性变换后应用 ReLU 激活函数后的结果。
- `None`:第二个返回值未使用。

## 关键函数
- `jax.checkpoint`:用于保存计算图的一部分,以避免重新计算。
- `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`:指定策略以便在检查点中处理没有批次维度的点积操作。

通过这种方式使用 jax.checkpoint,我们手动控制 JAX 的自动微分在前向传播和反向传播之间保存哪些值,从而不依赖于 XLA 优化为我们选择。