jax.scipy.linalg.funm

目录

jax.scipy.linalg.funm#

jax.scipy.linalg.funm(A, func, disp=True)[源代码][源代码]#

评估一个矩阵值函数

JAX 实现的 scipy.linalg.funm()

参数:
  • A (ArrayLike) – 形状为 (N, N) 的数组,用于计算该函数。

  • func (Callable[[Array], Array]) – 可调用对象,接受一个标量参数并返回一个标量结果。表示要在矩阵 A 的特征值上评估的函数。

  • disp (bool) – 如果为真(默认),错误信息不会返回。与 scipy 的版本不同,JAX 不会尝试在运行时显示信息。

  • compute_expm – (N, N) array_like 或 None,可选。如果提供,则为 A 的矩阵指数。这在 func 是指数函数时用于提高效率。如果未提供,则会在内部计算。默认为 None。

返回:

A 形状相同的数组,包含 funcA 的特征值上计算的结果。

返回类型:

Array | tuple[Array, Array]

备注

JAX 实现的返回 dtype 可能与 scipy 的不同;具体来说,在数组值的所有虚部都接近于零的情况下,SciPy 函数可能返回一个实值数组,而 JAX 实现将返回一个复值数组。

示例

应用一个任意的矩阵函数:

>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> def func(x):
...   return jnp.sin(x) + 2 * jnp.cos(x)
>>> jax.scipy.linalg.funm(A, func)  
Array([[ 1.2452652 +0.j, -0.3701772 +0.j],
       [-0.55526584+0.j,  0.6899995 +0.j]], dtype=complex64)

比较两种计算矩阵指数的方法:

>>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp)
>>> expA_2 = jax.scipy.linalg.expm(A)
>>> jnp.allclose(expA_1, expA_2, rtol=1E-4)
Array(True, dtype=bool)