jax.自定义_jvp#
- class jax.custom_jvp(fun, nondiff_argnums=())[源代码][源代码]#
为自定义 JVP 规则定义设置一个可 JAX 变换的函数。
这个类旨在用作函数装饰器。实例是可调用的,其行为类似于装饰器所应用的基础函数,除非应用了微分变换(如
jax.jvp()
或jax.grad()
),在这种情况下,将使用自定义的用户提供的 JVP 规则函数,而不是跟踪并执行基础函数实现的自动微分。有两种实例方法可用于定义自定义 JVP 规则:
defjvp()
用于为函数的所有输入定义一个 单一 自定义 JVP 规则,以及为了方便起见defjvps()
,它包装了defjvp()
,并允许您为函数相对于其每个参数的偏导数提供单独的定义。例如:
@jax.custom_jvp def f(x, y): return jnp.sin(x) * y @f.defjvp def f_jvp(primals, tangents): x, y = primals x_dot, y_dot = tangents primal_out = f(x, y) tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot return primal_out, tangent_out
欲了解更多详细介绍,请参阅教程。
- 参数:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
- __init__(fun, nondiff_argnums=())[源代码][源代码]#
- 参数:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
方法
__init__
(fun[, nondiff_argnums])defjvp
(jvp[, symbolic_zeros])为此实例表示的函数定义一个自定义的 JVP 规则。
defjvps
(*jvps)为每个参数分别定义JVP的便捷包装器。
属性
jvp
symbolic_zeros
fun
nondiff_argnums