jax.custom_jvp.defjvps

jax.custom_jvp.defjvps#

custom_jvp.defjvps(*jvps)[源代码][源代码]#

为每个参数分别定义JVP的便捷包装器。

这个便捷的包装器不能与 nondiff_argnums 一起使用。

参数:

*jvps (Callable[..., ReturnValue] | None) – 一系列函数,每个函数对应 custom_jvp 函数的每个位置参数。每个函数以相应原始输入的切线值、原始输出和原始输入作为参数。请参见下面的示例。

返回:

无。

返回类型:

None

示例

>>> @jax.custom_jvp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
...           lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))