jax.scipy.linalg.expm_frechet

jax.scipy.linalg.expm_frechet#

jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[True] = True) tuple[Array, Array][源代码][源代码]#
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: Literal[False]) Array
jax.scipy.linalg.expm_frechet(A: ArrayLike, E: ArrayLike, *, method: str | None = None, compute_expm: bool = True) Array | tuple[Array, Array]

计算矩阵指数的Frechet导数。

JAX 实现的 scipy.linalg.expm_frechet()

参数:
  • A – 形状为 (..., N, N) 的数组

  • E – 形状为 (..., N, N) 的数组;指定导数的方向。

  • compute_expm – 如果为真(默认),则计算并返回 expm(A)

  • method – 被 JAX 忽略

返回:

如果 compute_expm 为 True,则返回元组 (expm_A, expm_frechet_AE),否则返回数组 expm_frechet_AE。两个返回的数组形状均为 (..., N, N)

示例

我们可以使用这个API来计算矩阵 A 的指数,以及它在方向 E 上的导数:

>>> key1, key2 = jax.random.split(jax.random.key(3372))
>>> A = jax.random.normal(key1, (3, 3))
>>> E = jax.random.normal(key2, (3, 3))
>>> expmA, expm_frechet_AE = jax.scipy.linalg.expm_frechet(A, E)

这同样可以使用 JAX 的自动微分方法来计算;在这里,我们将使用 jax.jvp() 计算 expm()E 方向上的导数,并得到相同的结果:

>>> expmA2, expm_frechet_AE2 = jax.jvp(jax.scipy.linalg.expm, (A,), (E,))
>>> jnp.allclose(expmA, expmA2)
Array(True, dtype=bool)
>>> jnp.allclose(expm_frechet_AE, expm_frechet_AE2)
Array(True, dtype=bool)