jax.自定义_vjp#
- class jax.custom_vjp(fun, nondiff_argnums=())[源代码][源代码]#
为自定义 VJP 规则定义设置一个可 JAX 变换的函数。
此类旨在用作函数装饰器。实例是可调用的,其行为类似于应用了装饰器的底层函数,除非应用了反向模式微分变换(如
jax.grad()
),在这种情况下,将使用用户提供的自定义 VJP 规则函数,而不是跟踪并执行底层函数实现的自动微分。有一个实例方法,defvjp()
,可用于定义自定义 VJP 规则。此装饰器禁止使用前向模式自动微分。
例如:
@jax.custom_vjp def f(x, y): return jnp.sin(x) * y def f_fwd(x, y): return f(x, y), (jnp.cos(x), jnp.sin(x), y) def f_bwd(res, g): cos_x, sin_x, y = res return (cos_x * g * y, sin_x * g) f.defvjp(f_fwd, f_bwd)
欲了解更多详细介绍,请参阅教程。
- 参数:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
- __init__(fun, nondiff_argnums=())[源代码][源代码]#
- 参数:
fun (Callable[..., ReturnValue])
nondiff_argnums (Sequence[int])
方法