jax.custom_vjp.defvjp

jax.custom_vjp.defvjp#

custom_vjp.defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)[源代码][源代码]#

为此实例表示的函数定义一个自定义的 VJP 规则。

参数:
  • fwd (Callable[..., tuple[ReturnValue, Any]]) – 一个表示自定义 VJP 规则前向传递的 Python 可调用对象。当没有 nondiff_argnums 时,fwd 函数的输入签名与底层原始函数的输入签名相同。它应该返回一对输出,其中第一个元素表示原始输出,第二个元素表示从正向传递中存储的任何“残差”值,以便在反向传递中由函数 bwd 使用。输入参数和输出对的元素可以是数组或嵌套的元组/列表/字典。

  • bwd (Callable[..., tuple[Any, ...]]) – 一个表示自定义 VJP 规则反向传递的 Python 可调用对象。当没有 nondiff_argnums 时,bwd 函数接受两个参数,其中第一个是 fwd 在前向传递中产生的“残差”值,第二个是与原始函数输出具有相同结构的输出余切。bwd 的输出必须是一个元组,其长度等于原始函数的参数数量,并且元组元素可以是数组或嵌套的元组/列表/字典,以便匹配原始输入参数的结构。

  • symbolic_zeros (bool) – 布尔值,决定是否向 fwdbwd 规则指示符号零。启用此选项允许自定义导数规则检测某些输入以及某些输出余切向量是否不参与微分。如果为 True: * fwd 必须接受,在构成原始函数参数的 pytree 中,每个叶值 x 的位置上,一个具有两个属性的对象(类型为 jax.custom_derivatives.CustomVJPPrimal):valueperturbedvalue 字段是原始的主参数,perturbed 是一个布尔值。perturbed 位指示参数是否参与微分(即,如果为 False,则相应的雅可比矩阵“列”为零)。 * bwd 将在其余切向量参数中传递表示静态符号零的对象,对应于未扰动的值;否则,仅传递标准的 JAX 类型(例如类似数组的对象)。 将此选项设置为 True 允许这些规则检测某些输入和输出是否不参与微分,但代价是特殊处理。例如: * fwd 的签名发生变化,并且它传递的对象不能直接从规则中输出。 * bwd 规则传递的对象不完全类似于数组,并且不能传递给大多数 jax.numpy 函数。 * 原始函数参数中涉及的任何自定义 pytree 节点在其解包函数中必须接受作为 fwd 规则输入叶子的两个字段记录对象。 默认 False

  • optimize_remat (bool) – boolean,一个实验性标志,用于在使用 jax.remat() 时启用自动优化。当 fwd 规则是一个不透明的调用(如 Pallas 内核或自定义调用)时,这将最为有用。默认 False

返回:

无。

返回类型:

None

示例

>>> @jax.custom_vjp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))