jax.hessian

目录

jax.hessian#

jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[源代码][源代码]#

fun 的 Hessian 矩阵,以密集数组形式表示。

参数:
  • fun (Callable) – 要计算其Hessian的函数。其在``argnums``指定位置的参数应为数组、标量或标准Python容器。它应返回数组、标量或标准Python容器。

  • argnums (int | Sequence[int]) – 可选,整数或整数序列。指定要区分的位置参数(默认 0)。

  • has_aux (bool) – 可选,布尔值。指示 fun 是否返回一个对,其中第一个元素被认为是需要微分的数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic (bool) – 可选,布尔值。指示 fun 是否承诺为全纯函数。默认值为 False。

返回:

一个与 fun 具有相同参数的函数,用于计算 fun 的 Hessian 矩阵。

返回类型:

Callable

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]

hessian() 是对通常 Hessian 定义的泛化,支持嵌套的 Python 容器(即 pytrees)作为输入和输出。jax.hessian(fun)(x) 的树结构是通过将 fun(x) 的结构与 x 的结构的两份副本的树积形成的。两个树结构的树积是通过将第一个树的每个叶子替换为第二个树的副本形成的。例如:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

因此,jax.hessian(fun)(x) 树结构中的每个叶子对应于 fun(x) 的一个叶子和 x 的一对叶子。对于 jax.hessian(fun)(x) 中的每个叶子,如果对应的 fun(x) 的数组叶子具有形状 (out_1, out_2, ...),而对应的 x 的数组叶子分别具有形状 (in_1_1, in_1_2, ...)(in_2_1, in_2_2, ...),那么 Hessian 叶子具有形状 (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)。换句话说,Python 树结构表示 Hessian 的块结构,块由输入和输出 pytrees 确定。

特别是,当函数输入 x 和输出 fun(x) 各自是一个单独的数组时(不涉及 pytrees),会生成一个数组,如上面的 g 示例所示。如果 fun(x) 的形状为 (out1, out2, ...) ,而 x 的形状为 (in1, in2, ...) ,那么 jax.hessian(fun)(x) 的形状为 (out1, out2, ..., in1, in2, ..., in1, in2, ...) 。要将 pytrees 展平为 1D 向量,请考虑使用 jax.flatten_util.flatten_pytree()