jax.jvp#
- jax.jvp(fun, primals, tangents, has_aux=False)[源代码][源代码]#
计算
fun的(前向模式)雅可比向量积。- 参数:
fun (Callable) – 要微分的函数。其参数应为数组、标量或标准Python容器中的数组或标量。它应返回一个数组、标量或标准Python容器中的数组或标量。
primals –
fun的雅可比矩阵应在其上进行评估的原始值。应为参数的元组或列表,其长度应等于fun的位置参数的数量。tangents – 应计算其雅可比向量积的切向量。应为切向量的元组或列表,具有与
primals相同的树结构和数组形状。has_aux (bool) – 可选,布尔值。指示
fun是否返回一个对,其中第一个元素被认为是需要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。
- 返回:
如果
has_aux是False,返回一个(primals_out, tangents_out)对,其中primals_out是fun(*primals),而tangents_out是function在primals处评估的雅可比向量积与tangents。tangents_out值具有与primals_out相同的 Python 树结构和形状。如果has_aux是True,返回一个(primals_out, tangents_out, aux)元组,其中aux是fun返回的辅助数据。- 返回类型:
tuple[Any, …]
例如:
>>> import jax >>> >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,)) >>> print(primals) 0.09983342 >>> print(tangents) 0.19900084