custom_vjp 和 nondiff_argnums 更新指南

custom_vjpnondiff_argnums 更新指南#

mattjj@ 2020年10月14日

本文假设读者熟悉 jax.custom_vjp,如 JAX 可转换 Python 函数的自定义导数规则 笔记本中所述。

更新什么#

在 JAX PR #4008 之后,传递给 custom_vjp 函数的 nondiff_argnums 不能是 Tracer(或 Tracer 的容器),这基本上意味着为了允许代码进行任意变换,nondiff_argnums 不应用于数组值的参数。相反,nondiff_argnums 应该仅用于非数组值,如 Python 可调用对象、形状元组或字符串。

无论我们过去如何使用 nondiff_argnums 来处理数组值,我们都应该直接将它们作为常规参数传递。在 bwd 规则中,我们需要为它们生成值,但我们可以直接生成 None 值来表示没有对应的梯度值。

例如,这里展示了编写 clip_gradient 方法,当 hi 和/或 lo 是来自某些 JAX 变换的 Tracer 时,这种方法将无法工作。

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

这是的、令人惊叹的方式,支持任意变换:

import jax

@jax.custom_vjp  # no nondiff_argnums!
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save lo and hi values as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

如果你使用旧方法而不是新方法,在任何可能出错的情况下(即当有 Tracer 传递给 nondiff_argnums 参数时),你都会得到一个明显的错误。

这里有一个实际需要使用 nondiff_argnumscustom_vjp 的情况:

from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def skip_app(f, x):
  return f(x)

def skip_app_fwd(f, x):
  return skip_app(f, x), None

def skip_app_bwd(f, _, g):
  return (g,)

skip_app.defvjp(skip_app_fwd, skip_app_bwd)

解释#

Tracer 传递到 nondiff_argnums 参数中一直存在问题。虽然有些情况下可以正确工作,但其他情况会导致复杂且令人困惑的错误信息。

这个bug的本质在于 nondiff_argnums 的实现方式非常类似于词法闭包。但当时词法闭包在 Tracer 上并不打算与 custom_jvp/custom_vjp 一起工作。以这种方式实现 nondiff_argnums 是一个错误!

PR #4008 修复了 custom_jvpcustom_vjp 的所有词法闭包问题。 太棒了!也就是说,现在 custom_jvpcustom_vjp 函数和规则可以随心所欲地闭包 Tracer。对于所有非自动微分变换,事情将会正常工作。对于自动微分变换,我们将得到一个清晰的错误信息,说明为什么我们不能对 custom_jvpcustom_vjp 闭包的值进行微分:

检测到对闭包值的 custom_jvp 函数进行微分。这是不支持的,因为自定义 JVP 规则仅指定了如何对显式输入参数的 custom_jvp 函数进行微分。

尝试将闭包值作为参数传递给 custom_jvp 函数,并调整 custom_jvp 规则。

在加强和强化 custom_jvpcustom_vjp 的过程中,我们发现允许 custom_vjp 在其 nondiff_argnums 中接受 Tracer 会需要大量的簿记工作:我们需要重写用户的 fwd 函数以返回残差值,并重写用户的 bwd 函数以正常接受这些残差(而不是像 nondiff_argnums 那样将它们作为特殊的前导参数接受)。这似乎可能还可以管理,直到你考虑到我们必须如何处理任意的 pytrees!此外,这种复杂性是不必要的:如果用户代码将类似数组的不可微分参数像普通参数和残差一样处理,一切都已经正常工作。(在 #4039 之前,JAX 可能会对涉及整数值输入和输出的自动微分提出抱怨,但在 #4039 之后,这些将只是正常工作!)

custom_vjp 不同,使 custom_jvpnondiff_argnums 参数一起工作很容易,这些参数是 Tracer。因此,这些更新只需要在 custom_vjp 中进行。