jax.scipy.linalg.expm

目录

jax.scipy.linalg.expm#

jax.scipy.linalg.expm(A, *, upper_triangular=False, max_squarings=16)[源代码][源代码]#

计算矩阵指数

JAX 实现的 scipy.linalg.expm()

参数:
  • A (ArrayLike) – 形状为 (..., N, N) 的数组

  • upper_triangular (bool) – 如果为真,则假设 A 是上三角矩阵。默认=False。

  • max_squarings (int) – 缩放和平方法近似方法中的平方次数(默认:16)。

返回:

形状为 (..., N, N) 的数组,包含矩阵 A 的指数。

返回类型:

Array

备注

这使用了缩放和平方法近似,其计算复杂度由可选的 max_squarings 参数控制。理论上,所需的平方次数为 max(0, ceil(log2(norm(A))) - c),其中 norm(A) 是 L1 范数,对于 float64/complex128 类型,c=2.42,对于 float32/complex64 类型,c=1.97

示例

expm 是矩阵指数,并且具有与更熟悉的标量指数类似的性质。对于标量 ab\(e^{a + b} = e^a e^b\)。然而,对于矩阵,这个性质仅在 AB 可交换(AB = BA)时成立。在这种情况下,expm(A+B) = expm(A) @ expm(B)

>>> A = jnp.array([[2, 0],
...                [0, 1]])
>>> B = jnp.array([[3, 0],
...                [0, 4]])
>>> jnp.allclose(jax.scipy.linalg.expm(A+B),
...              jax.scipy.linalg.expm(A) @ jax.scipy.linalg.expm(B),
...              rtol=0.0001)
Array(True, dtype=bool)

如果矩阵 X 是可逆的,那么 expm(X @ A @ inv(X)) = X @ expm(A) @ inv(X)

>>> X = jnp.array([[3, 1],
...                [2, 5]])
>>> X_inv = jax.scipy.linalg.inv(X)
>>> jnp.allclose(jax.scipy.linalg.expm(X @ A @ X_inv),
...              X @ jax.scipy.linalg.expm(A) @ X_inv)
Array(True, dtype=bool)