jax.hessian#
- jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)[源代码][源代码]#
fun
的 Hessian 矩阵,以密集数组形式表示。- 参数:
- 返回:
一个与
fun
具有相同参数的函数,用于计算fun
的 Hessian 矩阵。- 返回类型:
Callable
>>> import jax >>> >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6 >>> print(jax.hessian(g)(jax.numpy.array([1., 2.]))) [[ 6. -2.] [ -2. -480.]]
hessian()
是对通常 Hessian 定义的泛化,支持嵌套的 Python 容器(即 pytrees)作为输入和输出。jax.hessian(fun)(x)
的树结构是通过将fun(x)
的结构与x
的结构的两份副本的树积形成的。两个树结构的树积是通过将第一个树的每个叶子替换为第二个树的副本形成的。例如:>>> import jax.numpy as jnp >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])} >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.})) {'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]], [[ 0., 0.], [ 0., 12.]]], dtype=float32), 'b': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)}, 'b': {'a': Array([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32), 'b': Array([[[0. , 0. ], [0. , 0. ]], [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
因此,
jax.hessian(fun)(x)
树结构中的每个叶子对应于fun(x)
的一个叶子和x
的一对叶子。对于jax.hessian(fun)(x)
中的每个叶子,如果对应的fun(x)
的数组叶子具有形状(out_1, out_2, ...)
,而对应的x
的数组叶子分别具有形状(in_1_1, in_1_2, ...)
和(in_2_1, in_2_2, ...)
,那么 Hessian 叶子具有形状(out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)
。换句话说,Python 树结构表示 Hessian 的块结构,块由输入和输出 pytrees 确定。特别是,当函数输入
x
和输出fun(x)
各自是一个单独的数组时(不涉及 pytrees),会生成一个数组,如上面的g
示例所示。如果fun(x)
的形状为(out1, out2, ...)
,而x
的形状为(in1, in2, ...)
,那么jax.hessian(fun)(x)
的形状为(out1, out2, ..., in1, in2, ..., in1, in2, ...)
。要将 pytrees 展平为 1D 向量,请考虑使用jax.flatten_util.flatten_pytree()
。