jax.custom_jvp.defjvp

jax.custom_jvp.defjvp#

custom_jvp.defjvp(jvp, symbolic_zeros=False)[源代码][源代码]#

为此实例表示的函数定义一个自定义的 JVP 规则。

参数:
  • jvp (Callable[..., tuple[ReturnValue, ReturnValue]]) – 一个表示自定义 JVP 规则的 Python 可调用对象。当没有 nondiff_argnums 时,jvp 函数应接受两个参数,其中第一个是原始输入的元组,第二个是切线输入的元组。两个元组的长度都等于 custom_jvp 函数的参数数量。jvp 函数应输出一个对,其中第一个元素是原始输出,第二个元素是切线输出。输入和输出元组的元素可以是数组或任何嵌套的元组/列表/字典。

  • symbolic_zeros (bool) – 布尔值,指示规则是否应在切线参数中传递表示静态符号零的对象,以对应未扰动的值;否则,仅传递标准 JAX 类型(例如类似数组的对象)。将此选项设置为 True 允许 JVP 规则检测某些输入是否不参与微分,但代价是需要对这些对象进行特殊处理(例如,不能将它们传递给 jax.numpy 函数)。默认 False

返回:

返回 jvp 以便 defjvp 可以用作装饰器。

返回类型:

Callable[…, tuple[ReturnValue, ReturnValue]]

示例

>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> @f.defjvp
... def f_jvp(primals, tangents):
...   x, y = primals
...   x_dot, y_dot = tangents
...   primal_out = f(x, y)
...   tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
...   return primal_out, tangent_out
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))