jax.remat / jax.checkpoint 更改:你需要知道的内容#

内容#

发生了什么?#

#11830 起,我们将启用 jax.checkpoint() 的新实现,即 jax.remat()(这两个名称互为别名)。对于大多数代码,不会有任何变化。 但在某些边缘情况下可能会有一些可观察到的差异;请参阅 升级后可能出现的问题是什么?

我如何禁用更改,并暂时恢复到旧的行为?#

如果你对这次更改有任何问题,通过版本 jax==0.3.16 你可以通过将 jax_new_checkpoint 配置选项设置为 False 来关闭新实现,可以通过以下任何一种方式进行设置:

  1. 设置shell环境变量 JAX_NEW_CHECKPOINT=0

  2. 执行 jax.config.update('jax_new_checkpoint', False);

  3. 如果你使用 absl 解析标志,请传递 --jax_new_checkpoint=False 选项。

如果你需要恢复到旧的实现,请在GitHub问题中联系我们,以便我们能让新的实现为你工作。

截至 jax==0.3.17jax_new_checkpoint 配置选项已不再可用。如果您遇到问题,请在 问题跟踪器 上联系我们,以便我们帮助修复!

我们为什么要这样做?#

在撰写本文时,JAX 有两个 jax.checkpoint 的并行实现。新的实现已经使用了数月(例如,由 Pax 和 Flaxformer/T5X 使用),但它是基于选择加入的,尚未默认启用。

我们希望将新的实现切换为默认启用,然后删除旧的实现。使用新的实现并移除旧的实现,为用户带来几个好处。

用户可自定义的重计算策略#

新实现的主要优势是新增了与 policy 参数对应的功能。其理念是让用户在自动微分的前向传播过程中,精确控制哪些中间结果被保存(而不是重新计算)。通过控制内存使用与重新计算之间的权衡,用户可以获得显著的性能提升,特别是在大型模型和我们的 LLM MLPerf 提交中!

此功能的完整文档仍在编写中,但这里有一个快速示例:

from functools import partial
import jax

def apply_layer(W, x):
  return jnp.sin(jnp.dot(W, x))

@partial(jax.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
def predict(params, x):
  for W in params[:-1]:
    x = apply_layer(W, x)
  return jnp.dot(params[-1], x)

通过在此处应用 jax.checkpoint 并设置 policy=jax.checkpoint_policies.checkpoint_dots,我们确保在正向传递期间只保存矩阵乘法的结果。来自 cos 应用的雅可比系数值,以及计算它们所需的 sin 应用的值,不会在正向传递中保存,而是在反向传递期间重新计算。(这种策略在 TPU 上可能很有效,其中元素级计算实际上是免费的,但矩阵单元的计算结果值得保存。)

能够重新实例化常量,而不仅仅是依赖于参数的数据操作#

旧的 jax.checkpoint 实现实际上无法在没有对装饰函数参数的数据依赖的情况下重新生成计算。考虑这个简单的例子:

@jax.checkpoint
def f(x):
  a = some_function(jnp.arange(10_000_000))  # `a` does not depend on `x`
  return a * x

旧的 jax.checkpoint 实现被迫保存 a 的值,这可能需要大量内存。新的 jax.checkpoint 实现可以重新生成 a 的值,而不是保存它。

在某些情况下,Python 开销显著减少#

新的 jax.checkpoint 在某些情况下显著减少了 Python 开销。简单的开销基准测试 速度提升了 10 倍。这些开销仅在急切逐操作执行时出现,因此在常用情况下,在 jax.jit 或类似情况下使用 jax.checkpoint 时,速度提升并不相关。但仍然很不错!

通过简化内部机制启用新的 JAX 功能#

这一改变也为未来的用户带来了巨大的好处,例如自定义批处理规则(vmapcustom_vjp 类似物)和 custom_vjp 的前向可微升级。它还显著降低了 JAX 代码库部分的复杂性,这将有利于维护和一般性的错误修复。

升级后可能出现哪些问题?#

无害的数值变化#

由于新的实现可以重新计算更多的计算,包括那些可能涉及大常数的计算,一些代码可能会看到小的数值变化。任何数值变化的大小都应该在我们从改变编译器优化(如浮点操作的重排序)中预期的范围内。但是,一些过于严格的测试容差可能需要稍微放宽。

concrete=True 选项已被移除。#

旧的 jax.checkpoint 实现有一个布尔值 concrete 选项,该选项允许在具体的 Python 值上进行追踪(而不是延迟所有计算并在抽象值上进行追踪)。该选项很少使用,并且在使用它的场合中,有更简单的替代方案。因此,我们在新的 jax.checkpoint 中移除了该选项。

例如,在Google代码中广泛使用的 concrete=True 是为了支持传递像 is_training 这样的参数:

@partial(jax.checkpoint, concrete=True)  # OLD jax.checkpoint API
def foo(x, is_training):
  if is_training:
    return g(x)
  else:
    return h(x)

通过新的 jax.checkpoint 实现,我们可以使用 static_argnums 选项来实现相同的功能:

@partial(jax.checkpoint, static_argnums=(1,))  # NEW jax.checkpoint API
def foo(x, is_training):
  if is_training:
    ...

如果需要在静态参数上执行 jax.numpy 操作,并在 Python 追踪期间计算其数值结果而不是延迟计算,我们可以结合使用 static_argnumsjax.ensure_compile_time_eval()。但看起来你不太可能需要这个!