JAX 可变换函数的自定义 JVP/VJP 规则#

这是一份设计文档,解释了 jax.custom_jvpjax.custom_vjp 设计与实现背后的一些思考。面向用户的文档,请参见 教程笔记本

在 JAX 中定义微分规则有两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 为已经是 JAX 可转换的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例以及它们所有的变换规则,例如调用来自求解器、模拟器或其他通用数值计算系统的函数。

本文档仅关于 #1。

内容#

目标#

我们希望 用户 能够自定义其代码的前向和/或反向模式微分行为。这种自定义

  1. 应该在其工作方式和与其他 JAX 变换组合的方式上具有 清晰且一致的语义;并且

  2. 在支持使用案例和工作流程方面应具有 灵活性 ,如在 AutogradPyTorch 中,包括涉及 Python 控制流微分和 NaN 调试工作流程的案例。

作为 JAX 开发者,我们希望编写库函数,例如 logitexpit,这些函数是根据其他原语定义的,但在微分的目的上具有原语般的行为,即我们希望为它们定义自定义的微分规则,这些规则可能在数值上更稳定或性能更好。特别是,我们不希望为 logitexpit 这样的函数指定 vmapjit 规则。

作为一个延伸目标,我们希望使 JAX 成为一个适合高级用户的环境,他们希望为高阶函数(如 fixed_pointodeint 等)添加自定义微分规则;这个设计文档不会解决那个问题,但我们希望确保我们不会排除解决那个问题的良好方案。

也就是说,我们的主要目标是

  1. 解决 vmap-removes-custom-jvp 语义问题 (#1249),并且

  2. 允许在自定义 VJPs 中使用 Python,例如调试 NaN (#1275)。

次要目标包括 3. 清理和简化用户体验(符号零、关键字参数等) 4. 朝着用户可以轻松添加 fixed_pointodeintroot 等的世界迈进。

总的来说,我们希望关闭 #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938,并替换自定义变换机制(来自 #636, #818 及其他)。

非目标#

以下是我们打算实现的目标:

  1. custom_transforms 机制旨在提供一个通用的转换机制,用于定制行为,原则上(尽管实际上从未真正使用过)允许用户在某种程度上继承其他转换的“透明”行为的同时,定制任何转换的规则。我们只打算解决微分(JVP 和 VJP,分别)的定制问题。 微分是唯一实际请求的情况,通过专门针对微分,我们可以减少复杂性并提高灵活性。要控制所有规则,只需编写一个原始函数即可。

  2. 我们不会优先考虑数学美学 在用户端的灵活性和清晰度,以及在实现端的简单性之上。特别是,虽然自定义 VJP 签名 a -> (b, CT b --o CT a) 在数学上令人愉悦,但如果由于返回类型中的闭包而在 Python 机制中难以实现,我们很乐意做一些更明确处理残差的事情。

  3. 序列化支持,即可以加载并进一步进行JAX转换的阶段化序列化程序表示形式,而不是仅进行评估,目前不在这些自定义JVP/VJP转换规则的范围内。序列化不仅对希望保存其计算表示(并在加载后进行转换)的研究人员有用,而且对于未来的考虑,例如在Python外部实现jaxpr转换,或将jaxpr作为MLIR方言,也可能有用。通过将此定义为本次设计目的的非目标,我们在何处可以存储Python可调用对象方面有更少的约束。

主要问题描述#

vmap-removes-custom-jvp 语义问题#

vmap-removes-custom-jvp 语义问题在于,vmap 无法正确地与具有 custom_transforms 规则的函数的微分组合:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  return 2. * x

# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
  return f(x), lambda g: 3. * x  # 3 instead of 2

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # 3.
vmap(grad(f))(np.ones(4))  # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4))  # [2., 2., 2., 2.]

最后一条 grad-of-vmap 行产生了意外的结果!通常,应用 vmap,或者实际上任何非微分变换,都会导致自定义微分规则被移除。(应用 jvp 会在定义自定义 VJP 规则时导致失败。)

问题存在是因为转换类似于重写,而 vmap 转换实际上重写了函数,使其不再调用新引入的、有自定义规则的原语(因此 grad 随后不会产生自定义规则的结果)。更详细地说,custom_transforms 机制设置了一些东西,使得评估 f(x) 时应用该函数

{ lambda  ; ; a.
  let b = f_primitive a
  in [b] }

其中 f_primitive 是一个新的原语(为每个 custom_transforms 函数引入,实际上是为每次函数调用引入),与其关联的是自定义的 VJP 规则。当我们评估 grad(f)(x) 时,微分机制遇到 f_primitive 并使用自定义规则处理它。

然而,因为 f_primitivevmap透明 的,从 vmap 操作(实际上是通过内联)f_primitive 的定义来看,函数 vmap(f) 实际上是

{ lambda  ; ; a.
  let b = mul 2. a
  in [b] }

换句话说,vmap 根据其底层原语及其转换规则重写函数,完全移除 f_primitive

更一般地说,因为 vmap(f) 的语义是根据对 f 的调用来定义的,所以移除自定义导数规则在语义上是不一致的。也就是说,由于我们定义

vmap(f)(xs) == np.stack([f(x) for x in xs])

我们必须有

jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))

然而,当 f 定义了自定义导数规则时,不会观察到此属性,因为自定义导数规则用于右侧版本,而不是左侧版本。

这个问题并不是特定于 vmap 的;它适用于所有变换,这些变换的语义是通过对函数 f 的调用来定义的,而不是将其重写为另一个函数。mask 变换也属于这一类。微分变换和假设的所有一元函数变为余弦函数的变换不属于这一类。

(像自定义 vmap 规则这样的额外自定义规则之间的交互可能会变得更加复杂,这表明 custom_transforms 的问题框架过于宽泛。)

Python 的灵活性问题#

在 JAX 中,与 AutogradPyTorch 类似,但与 TF1 不同,Python 函数的微分是在函数执行和追踪时进行的。这种行为让用户感到高兴,原因有几个。

首先也是最重要的是,它支持基于pdb的工作流程,例如用于检查数值或捕获NaN。 也就是说,用户可以使用标准的Python调试器和其他Python原生工具来调试他们的代码,甚至能够检查运行时值以理解数值行为并捕获像NaN这样的基本运行时错误。事实上,就在我为这个设计对应的PR工作时,特别是在处理odeint原语时,我多次使用运行时值检查来调试问题,这增加了我对这是Python中关键用户工作流程的信心。一个特别方便的技巧,我在JAX和Autograd中多次使用过,就是在自定义VJP规则中插入一个调试器断点,以便在反向传播的特定点进入调试器。

第二,它允许区分Python原生控制流。 我们不确定这在实际的最终软件制品中使用频率如何,但当用户初次接触JAX或Autograd时,他们常常对这种自由印象深刻。我们将其置于JAX和Autograd的README、幻灯片和演示文稿的顶部是有原因的。放弃这种能力将是Autograd的一个倒退。我们希望JAX拥有最好的自动微分功能。

然而,custom_transforms 机制并不提供这种 Python 支持的灵活性。这是因为它是通过从用户函数和自定义微分规则的 Python 代码中提前形成 jaxpr 来实现的,这种代码会导致抽象值追踪错误:

# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
  if x > 0:
    return x
  else:
    return 0.

def f_vjp(x):
  return ...

jax.defvjp_all(f, f_vjp)

grad(f)(1.)  # Error!

解决方案思路#

主要思想是 dougalm@ 已经通过 core.call 解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务框架化为一个新的 Python 级调用原语(不要添加到 jaxpr 语言中;见下文)。这个新的调用原语与 core.call 一样,关联了一个用户 Python 函数,但另外还有一个表示 JVP 规则的第二个 Python 可调用对象。我们称这个新的调用原语为 custom_jvp_call

vmap 这样的变换与 custom_jvp_call 的交互方式与 core.call 相同:它们实际上直接通过它,并应用于底层 Python 可调用对象。从原理上讲,为了方便起见,用原语的柯里化版本书写,类似于 vmap 通过应用于要调用的函数来与 core.call 交互的方式:

vmap(call(f)) == call(vmap(f))

对于新的原语 custom_jvp_call,我们只需将 vmap 应用于它所包含的两个函数:

vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))

此行为意味着我们已经解决了 vmap-removes-custom-jvp 语义问题

jvp 变换的交互正如人们所预期的那样:它只是调用了 f_jvp

jvp(call(f)) == call(jvp(f))

jvp(custom_jvp_call(f, f_jvp)) == f_jvp

因为 custom_jvp_call 的行为类似于 core.call(而不是 xla.xla_call),它不会提升其输入的抽象级别(因为它没有延迟任何东西或分阶段执行任何东西),这意味着我们已经解决了 Python 灵活性问题:用户 Python 函数没有任何限制(除了 jvpvjp 通常所需的功能编程约束)。

关于评估和编译呢?这是两种“退出”JAX系统的方式,因为在这两个步骤之后,不能再应用额外的转换。因此,它们的规则是微不足道的:

eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))

eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))

换句话说,如果一个 JVP 规则还没有将 custom_jvp_call(f, f_jvp) 重写为 f_jvp,当我们使用 eval 进行评估或使用 jit 编排到 XLA 时,微分永远不会被应用,所以我们只是忽略 f_jvp 并表现得像 core.call。然而,由于接下来讨论的复杂性,custom_jvp_call 的部分评估规则必须更加复杂,因为部分评估不仅仅用于使用 jit 编排到 XLA。

唯一剩下的问题与“初始风格”的 jaxpr 形成原语有关,比如 lax.scan,以及它们的转换规则。这些代表了与编译不同的“分阶段输出到 jaxpr”,因为我们可以对分阶段输出的 jaxpr 执行额外的转换。也就是说,当 lax.scan 形成一个 jaxpr 时,它不会退出转换系统,因为当我们对 lax.scan 应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 所表示的函数。

另一种描述这种复杂性的方式是,像 lax.scan 这样的初始样式原语依赖于能够在保留语义的情况下往返于 jaxpr 并返回 Python 可调用对象的能力。这也必须意味着保留自定义微分规则的语义。

解决方案是使用一点动态作用域:当我们为初始样式的原语(如 lax_control_flow.py 中的那些)进行 jaxpr 的暂存时,我们在全局跟踪状态上设置一个标志。当该标志被设置时,我们不使用最终样式的 custom_jvp_call 原语,而是使用初始样式的 custom_jvp_call_jaxpr 原语,并预先将函数 ff_jvp 跟踪到 jaxpr 中,以使初始样式的处理更容易。custom_jvp_call_jaxpr 原语在其他方面与最终样式的版本类似。

(脚注:虽然从道德上讲,我们在绑定 custom_jvp_call_jaxpr 之前为 ff_jvp 都形成了 jaxprs,但我们需要延迟 f_jvp 的 jaxpr 的形成,因为它可能会调用自定义的 JVP 函数,因此急切处理会导致无限递归。我们在 thunk 中延迟了该 jaxpr 的形成。)

如果我们放弃 Python 灵活性问题,我们就可以只使用 custom_jvp_call_jaxpr 而不用单独的 Python 级原语 custom_jvp_call

API#

a -> b 函数的自定义 JVP 是通过 (a, Ta) -> (b, T b) 函数指定的:

# f :: a -> b
@jax.custom_jvp
def f(x):
  return np.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), np.cos(x) * t

f.defjvp(f_jvp)

(有趣的自动微分旁注:为了使规则适用于高阶微分,必须在 f_jvp 的主体中调用 f;这排除了 f 的内部与切线计算之间某些类型的工作共享。)

对于 a -> b 函数的自定义 VJP,是通过一个 a -> (b, c) 的前向传递函数与一个 (c, CT b) -> CT 的后向传递函数配对来指定的:

# f :: a -> b
@jax.custom_vjp
def f(x):
  return np.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), np.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
  return (cos_x * g,)

f.defvjp(f_fwd, f_bwd)

签名 a -> (b, CT b --o CT a) 在美学上更令人愉悦,但支持它会使实现更加复杂,并且可能需要妥协表达性需求。Python 可调用对象的基本原因是它们是不透明的(除非我们急切地将它们追踪到一个 jaxpr,这会限制表达性),在这种情况下,我们可能会返回一个闭包中包含 vmap 追踪器的可调用对象,我们需要在正向传递过程中了解这些信息。

我们可以添加便利的包装器,例如一次定义单个参数的 JVP 规则(就像我们在内部为原语所做的那样)。但由于这个提案本身已经足够复杂,我决定不增加便利层;现在让我们保持最小化。

API 还有一些其他的功能:

  • 输入和输出类型 a, b, 和 c 可以是任意 jaxtypes 的 pytrees。

  • 当可以使用 inspect 模块解析为位置时,支持按名称(关键字参数)传递参数。这是对 Python 3 改进的程序化检查参数签名能力的实验。我认为这是合理的但并不完善,这是一个不错的起点。(另见 #2069。)

  • 可以使用 nondiff_argnums 标记不可微分的参数,并且与 jitstatic_argnums 一样,这些参数不必是 JAX 类型。我们需要为这些参数如何传递给规则设定一个约定。对于类型签名为 (d, a) -> b 的原函数,其中 d 表示不可微分的类型,JVP 规则的签名是 (a, T a, d) -> T b,而 VJP 规则的反向组件签名是 (d, c, CT b) -> CT a。也就是说,在自定义 JVP 规则中,不可微分的参数按顺序在 primalstangents 之后传递,而在自定义 VJP 规则的反向函数中,它们按顺序在残差之前传递。

实现说明#

  • 更新 jax.experimental.odeint

    • 由于 odeint 是一个相当复杂的自定义 VJP 规则用户,除了仅仅更新它以使其正常工作之外,我还想将其修订为新自定义 VJP API 的规范用户,以此来测试该 API 是否良好。

    • 在此过程中,我对 odeint 实现进行了其他改进:

      • 移除展开/解开样板代码

      • 使用 lax.scan 来移除索引更新逻辑

      • 在简单摆的基准测试中,速度提升了20%以上

  • 为每个变换添加了一个自定义绑定方法,用于自定义派生调用原语,custom_jvp_callcustom_vjp_call。它类似于 core.call_bind,除了我们不处理环境跟踪:那些只是错误。

  • 添加了 custom_lin 原语,该原语在自定义 VJP 规则使用时会被分阶段输出到线性 jaxprs 中以进行转置。

    • 因为我们的反向模式自动微分被分解为线性化、部分求值和转置,我们的自定义VJP规则在两个独立的步骤中处理:一个在线性化期间,一个在转置期间。

    • 线性化步骤,即 custom_vjp_call 的 JVP 规则,将 custom_lin 应用于切线值;custom_lin 携带着用户的自定义反向传递函数,并且作为一个原语,它只有一个转置规则。

    • 此机制在 #636 中有更详细的描述。

  • 为了防止