JAX 可变换函数的自定义 JVP/VJP 规则#
这是一份设计文档,解释了 jax.custom_jvp
和 jax.custom_vjp
设计与实现背后的一些思考。面向用户的文档,请参见 教程笔记本。
在 JAX 中定义微分规则有两种方式:
使用
jax.custom_jvp
和jax.custom_vjp
为已经是 JAX 可转换的 Python 函数定义自定义微分规则;以及定义新的
core.Primitive
实例以及它们所有的变换规则,例如调用来自求解器、模拟器或其他通用数值计算系统的函数。
本文档仅关于 #1。
内容#
目标#
我们希望 用户 能够自定义其代码的前向和/或反向模式微分行为。这种自定义
应该在其工作方式和与其他 JAX 变换组合的方式上具有 清晰且一致的语义;并且
在支持使用案例和工作流程方面应具有 灵活性 ,如在 Autograd 和 PyTorch 中,包括涉及 Python 控制流微分和 NaN 调试工作流程的案例。
作为 JAX 开发者,我们希望编写库函数,例如 logit
和 expit
,这些函数是根据其他原语定义的,但在微分的目的上具有原语般的行为,即我们希望为它们定义自定义的微分规则,这些规则可能在数值上更稳定或性能更好。特别是,我们不希望为 logit
和 expit
这样的函数指定 vmap
或 jit
规则。
作为一个延伸目标,我们希望使 JAX 成为一个适合高级用户的环境,他们希望为高阶函数(如 fixed_point
、odeint
等)添加自定义微分规则;这个设计文档不会解决那个问题,但我们希望确保我们不会排除解决那个问题的良好方案。
也就是说,我们的主要目标是
次要目标包括 3. 清理和简化用户体验(符号零、关键字参数等) 4. 朝着用户可以轻松添加 fixed_point
、odeint
、root
等的世界迈进。
总的来说,我们希望关闭 #116, #1097, #1249, #1275, #1366, #1723, #1670, #1875, #1938,并替换自定义变换机制(来自 #636, #818 及其他)。
非目标#
以下是我们不打算实现的目标:
custom_transforms
机制旨在提供一个通用的转换机制,用于定制行为,原则上(尽管实际上从未真正使用过)允许用户在某种程度上继承其他转换的“透明”行为的同时,定制任何转换的规则。我们只打算解决微分(JVP 和 VJP,分别)的定制问题。 微分是唯一实际请求的情况,通过专门针对微分,我们可以减少复杂性并提高灵活性。要控制所有规则,只需编写一个原始函数即可。我们不会优先考虑数学美学 在用户端的灵活性和清晰度,以及在实现端的简单性之上。特别是,虽然自定义 VJP 签名
a -> (b, CT b --o CT a)
在数学上令人愉悦,但如果由于返回类型中的闭包而在 Python 机制中难以实现,我们很乐意做一些更明确处理残差的事情。序列化支持,即可以加载并进一步进行JAX转换的阶段化序列化程序表示形式,而不是仅进行评估,目前不在这些自定义JVP/VJP转换规则的范围内。序列化不仅对希望保存其计算表示(并在加载后进行转换)的研究人员有用,而且对于未来的考虑,例如在Python外部实现jaxpr转换,或将jaxpr作为MLIR方言,也可能有用。通过将此定义为本次设计目的的非目标,我们在何处可以存储Python可调用对象方面有更少的约束。
主要问题描述#
vmap-removes-custom-jvp 语义问题#
vmap-removes-custom-jvp 语义问题在于,vmap 无法正确地与具有 custom_transforms
规则的函数的微分组合:
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
return 2. * x
# f_vjp :: a -> (b, CT b --o CT a)
def f_vjp(x):
return f(x), lambda g: 3. * x # 3 instead of 2
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # 3.
vmap(grad(f))(np.ones(4)) # [3., 3., 3., 3.]
grad(lambda x: vmap(f)(x).sum())(np.ones(4)) # [2., 2., 2., 2.]
最后一条 grad-of-vmap 行产生了意外的结果!通常,应用 vmap
,或者实际上任何非微分变换,都会导致自定义微分规则被移除。(应用 jvp
会在定义自定义 VJP 规则时导致失败。)
问题存在是因为转换类似于重写,而 vmap
转换实际上重写了函数,使其不再调用新引入的、有自定义规则的原语(因此 grad
随后不会产生自定义规则的结果)。更详细地说,custom_transforms
机制设置了一些东西,使得评估 f(x)
时应用该函数
{ lambda ; ; a.
let b = f_primitive a
in [b] }
其中 f_primitive
是一个新的原语(为每个 custom_transforms
函数引入,实际上是为每次函数调用引入),与其关联的是自定义的 VJP 规则。当我们评估 grad(f)(x)
时,微分机制遇到 f_primitive
并使用自定义规则处理它。
然而,因为 f_primitive
对 vmap
是 透明 的,从 vmap
操作(实际上是通过内联)f_primitive
的定义来看,函数 vmap(f)
实际上是
{ lambda ; ; a.
let b = mul 2. a
in [b] }
换句话说,vmap
根据其底层原语及其转换规则重写函数,完全移除 f_primitive
。
更一般地说,因为 vmap(f)
的语义是根据对 f 的调用来定义的,所以移除自定义导数规则在语义上是不一致的。也就是说,由于我们定义
vmap(f)(xs) == np.stack([f(x) for x in xs])
我们必须有
jvp(vmap(f))(xs) == jvp(lambda xs: np.stack([f(x) for x in xs]))
然而,当 f
定义了自定义导数规则时,不会观察到此属性,因为自定义导数规则用于右侧版本,而不是左侧版本。
这个问题并不是特定于 vmap
的;它适用于所有变换,这些变换的语义是通过对函数 f
的调用来定义的,而不是将其重写为另一个函数。mask
变换也属于这一类。微分变换和假设的所有一元函数变为余弦函数的变换不属于这一类。
(像自定义 vmap
规则这样的额外自定义规则之间的交互可能会变得更加复杂,这表明 custom_transforms
的问题框架过于宽泛。)
Python 的灵活性问题#
在 JAX 中,与 Autograd 和 PyTorch 类似,但与 TF1 不同,Python 函数的微分是在函数执行和追踪时进行的。这种行为让用户感到高兴,原因有几个。
首先也是最重要的是,它支持基于pdb的工作流程,例如用于检查数值或捕获NaN。 也就是说,用户可以使用标准的Python调试器和其他Python原生工具来调试他们的代码,甚至能够检查运行时值以理解数值行为并捕获像NaN这样的基本运行时错误。事实上,就在我为这个设计对应的PR工作时,特别是在处理odeint
原语时,我多次使用运行时值检查来调试问题,这增加了我对这是Python中关键用户工作流程的信心。一个特别方便的技巧,我在JAX和Autograd中多次使用过,就是在自定义VJP规则中插入一个调试器断点,以便在反向传播的特定点进入调试器。
第二,它允许区分Python原生控制流。 我们不确定这在实际的最终软件制品中使用频率如何,但当用户初次接触JAX或Autograd时,他们常常对这种自由印象深刻。我们将其置于JAX和Autograd的README、幻灯片和演示文稿的顶部是有原因的。放弃这种能力将是Autograd的一个倒退。我们希望JAX拥有最好的自动微分功能。
然而,custom_transforms
机制并不提供这种 Python 支持的灵活性。这是因为它是通过从用户函数和自定义微分规则的 Python 代码中提前形成 jaxpr 来实现的,这种代码会导致抽象值追踪错误:
# old custom_transforms api to be replaced
@jax.custom_transforms
def f(x):
if x > 0:
return x
else:
return 0.
def f_vjp(x):
return ...
jax.defvjp_all(f, f_vjp)
grad(f)(1.) # Error!
解决方案思路#
主要思想是 dougalm@ 已经通过 core.call
解决了这些问题。也就是说,我们可以将为用户函数指定自定义 JVP 规则的任务框架化为一个新的 Python 级调用原语(不要添加到 jaxpr 语言中;见下文)。这个新的调用原语与 core.call
一样,关联了一个用户 Python 函数,但另外还有一个表示 JVP 规则的第二个 Python 可调用对象。我们称这个新的调用原语为 custom_jvp_call
。
像 vmap
这样的变换与 custom_jvp_call
的交互方式与 core.call
相同:它们实际上直接通过它,并应用于底层 Python 可调用对象。从原理上讲,为了方便起见,用原语的柯里化版本书写,类似于 vmap
通过应用于要调用的函数来与 core.call
交互的方式:
vmap(call(f)) == call(vmap(f))
对于新的原语 custom_jvp_call
,我们只需将 vmap
应用于它所包含的两个函数:
vmap(custom_jvp_call(f, f_jvp)) == custom_jvp_call(vmap(f), vmap(f_jvp))
此行为意味着我们已经解决了 vmap-removes-custom-jvp 语义问题。
jvp
变换的交互正如人们所预期的那样:它只是调用了 f_jvp
。
jvp(call(f)) == call(jvp(f))
jvp(custom_jvp_call(f, f_jvp)) == f_jvp
因为 custom_jvp_call
的行为类似于 core.call
(而不是 xla.xla_call
),它不会提升其输入的抽象级别(因为它没有延迟任何东西或分阶段执行任何东西),这意味着我们已经解决了 Python 灵活性问题:用户 Python 函数没有任何限制(除了 jvp
或 vjp
通常所需的功能编程约束)。
关于评估和编译呢?这是两种“退出”JAX系统的方式,因为在这两个步骤之后,不能再应用额外的转换。因此,它们的规则是微不足道的:
eval(call(f)) == eval(f)
jit(call(f)) == hlo_call(jit(f))
eval(custom_jvp_call(f, f_jvp)) == eval(f)
jit(custom_jvp_call(f, f_jvp)) == hlo_call(jit(f))
换句话说,如果一个 JVP 规则还没有将 custom_jvp_call(f, f_jvp)
重写为 f_jvp
,当我们使用 eval
进行评估或使用 jit
编排到 XLA 时,微分永远不会被应用,所以我们只是忽略 f_jvp
并表现得像 core.call
。然而,由于接下来讨论的复杂性,custom_jvp_call
的部分评估规则必须更加复杂,因为部分评估不仅仅用于使用 jit
编排到 XLA。
唯一剩下的问题与“初始风格”的 jaxpr 形成原语有关,比如 lax.scan
,以及它们的转换规则。这些代表了与编译不同的“分阶段输出到 jaxpr”,因为我们可以对分阶段输出的 jaxpr 执行额外的转换。也就是说,当 lax.scan
形成一个 jaxpr 时,它不会退出转换系统,因为当我们对 lax.scan
应用 jvp 或 vmap 时,我们需要将其应用于 jaxpr 所表示的函数。
另一种描述这种复杂性的方式是,像 lax.scan
这样的初始样式原语依赖于能够在保留语义的情况下往返于 jaxpr 并返回 Python 可调用对象的能力。这也必须意味着保留自定义微分规则的语义。
解决方案是使用一点动态作用域:当我们为初始样式的原语(如 lax_control_flow.py 中的那些)进行 jaxpr 的暂存时,我们在全局跟踪状态上设置一个标志。当该标志被设置时,我们不使用最终样式的 custom_jvp_call
原语,而是使用初始样式的 custom_jvp_call_jaxpr
原语,并预先将函数 f
和 f_jvp
跟踪到 jaxpr 中,以使初始样式的处理更容易。custom_jvp_call_jaxpr
原语在其他方面与最终样式的版本类似。
(脚注:虽然从道德上讲,我们在绑定 custom_jvp_call_jaxpr
之前为 f
和 f_jvp
都形成了 jaxprs,但我们需要延迟 f_jvp
的 jaxpr 的形成,因为它可能会调用自定义的 JVP 函数,因此急切处理会导致无限递归。我们在 thunk 中延迟了该 jaxpr 的形成。)
如果我们放弃 Python 灵活性问题,我们就可以只使用 custom_jvp_call_jaxpr
而不用单独的 Python 级原语 custom_jvp_call
。
API#
a -> b
函数的自定义 JVP 是通过 (a, Ta) -> (b, T b)
函数指定的:
# f :: a -> b
@jax.custom_jvp
def f(x):
return np.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), np.cos(x) * t
f.defjvp(f_jvp)
(有趣的自动微分旁注:为了使规则适用于高阶微分,必须在 f_jvp
的主体中调用 f
;这排除了 f
的内部与切线计算之间某些类型的工作共享。)
对于 a -> b
函数的自定义 VJP,是通过一个 a -> (b, c)
的前向传递函数与一个 (c, CT b) -> CT
的后向传递函数配对来指定的:
# f :: a -> b
@jax.custom_vjp
def f(x):
return np.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), np.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, g):
return (cos_x * g,)
f.defvjp(f_fwd, f_bwd)
签名 a -> (b, CT b --o CT a)
在美学上更令人愉悦,但支持它会使实现更加复杂,并且可能需要妥协表达性需求。Python 可调用对象的基本原因是它们是不透明的(除非我们急切地将它们追踪到一个 jaxpr,这会限制表达性),在这种情况下,我们可能会返回一个闭包中包含 vmap
追踪器的可调用对象,我们需要在正向传递过程中了解这些信息。
我们可以添加便利的包装器,例如一次定义单个参数的 JVP 规则(就像我们在内部为原语所做的那样)。但由于这个提案本身已经足够复杂,我决定不增加便利层;现在让我们保持最小化。
API 还有一些其他的功能:
输入和输出类型
a
,b
, 和c
可以是任意 jaxtypes 的 pytrees。当可以使用
inspect
模块解析为位置时,支持按名称(关键字参数)传递参数。这是对 Python 3 改进的程序化检查参数签名能力的实验。我认为这是合理的但并不完善,这是一个不错的起点。(另见 #2069。)可以使用
nondiff_argnums
标记不可微分的参数,并且与jit
的static_argnums
一样,这些参数不必是 JAX 类型。我们需要为这些参数如何传递给规则设定一个约定。对于类型签名为(d, a) -> b
的原函数,其中d
表示不可微分的类型,JVP 规则的签名是(a, T a, d) -> T b
,而 VJP 规则的反向组件签名是(d, c, CT b) -> CT a
。也就是说,在自定义 JVP 规则中,不可微分的参数按顺序在primals
和tangents
之后传递,而在自定义 VJP 规则的反向函数中,它们按顺序在残差之前传递。
实现说明#
更新
jax.experimental.odeint
由于
odeint
是一个相当复杂的自定义 VJP 规则用户,除了仅仅更新它以使其正常工作之外,我还想将其修订为新自定义 VJP API 的规范用户,以此来测试该 API 是否良好。在此过程中,我对
odeint
实现进行了其他改进:移除展开/解开样板代码
使用
lax.scan
来移除索引更新逻辑在简单摆的基准测试中,速度提升了20%以上
为每个变换添加了一个自定义绑定方法,用于自定义派生调用原语,
custom_jvp_call
和custom_vjp_call
。它类似于core.call_bind
,除了我们不处理环境跟踪:那些只是错误。添加了
custom_lin
原语,该原语在自定义 VJP 规则使用时会被分阶段输出到线性 jaxprs 中以进行转置。因为我们的反向模式自动微分被分解为线性化、部分求值和转置,我们的自定义VJP规则在两个独立的步骤中处理:一个在线性化期间,一个在转置期间。
线性化步骤,即
custom_vjp_call
的 JVP 规则,将custom_lin
应用于切线值;custom_lin
携带着用户的自定义反向传递函数,并且作为一个原语,它只有一个转置规则。此机制在 #636 中有更详细的描述。
为了防止