jax.grad

目录

jax.grad#

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码][源代码]#

创建一个评估 fun 梯度的函数。

参数:
  • fun (Callable) – 要微分的函数。其参数在由 argnums 指定的位置应为数组、标量或标准 Python 容器。在 argnums 指定的位置的参数数组必须是近似类型(即浮点数或复数类型)。它应返回一个标量(包括形状为 () 的数组,但不包括形状为 (1,) 等的数组)。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要对其进行微分的定位参数(默认值为0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被认为是需要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic (bool) – 可选, bool. 指示 fun 是否承诺为全纯函数。如果为 True,输入和输出必须是复数。默认为 False。

  • allow_int (bool) – 可选, bool. 是否允许对整数值输入进行微分。整数输入的梯度将具有平凡的向量空间数据类型(float0)。默认 False。

  • reduce_axes (Sequence[AxisName])

返回:

一个与 fun 具有相同参数的函数,用于计算 fun 的梯度。如果 argnums 是整数,则梯度与该整数指示的位置参数具有相同的形状和类型。如果 argnums 是整数元组,则梯度是与相应参数具有相同形状和类型的值元组。如果 has_aux 为 True,则返回一对 (梯度, 辅助数据)。

返回类型:

Callable

例如:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043