jax.scipy.linalg.inv

目录

jax.scipy.linalg.inv#

jax.scipy.linalg.inv(a, overwrite_a=False, check_finite=True)[源代码][源代码]#

返回一个方阵的逆矩阵

JAX 实现的 scipy.linalg.inv()

参数:
  • a (ArrayLike) – 形状为 (..., N, N) 的数组,指定要被求逆的方阵。

  • overwrite_a (bool) – 在 JAX 中未使用

  • check_finite (bool) – 在 JAX 中未使用

返回:

形状为 (..., N, N) 的数组,包含输入的逆矩阵。

返回类型:

Array

备注

在大多数情况下,显式计算矩阵的逆是不明智的。例如,要计算 x = inv(A) @ b,使用直接求解,如 jax.scipy.linalg.solve(),会更加高效且数值精确。

参见

示例

计算一个3x3矩阵的逆

>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> a_inv = jax.scipy.linalg.inv(a)
>>> a_inv  
Array([[ 0.        , -0.25      ,  0.5       ],
       [-0.25      ,  0.5       , -0.25000003],
       [ 0.5       , -0.25      ,  0.        ]], dtype=float32)

检查与逆矩阵相乘是否得到单位矩阵:

>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)

将逆矩阵乘以向量 b,以找到 a @ x = b 的解。

>>> b = jnp.array([1., 4., 2.])
>>> a_inv @ b
Array([ 0.  ,  1.25, -0.5 ], dtype=float32)

然而,需要注意的是,在这种情况下显式计算逆矩阵可能会导致性能下降和精度损失,因为问题规模增大。相反,你应该使用像 jax.scipy.linalg.solve() 这样的直接求解器:

>>> jax.scipy.linalg.solve(a, b)
 Array([ 0.  ,  1.25, -0.5 ], dtype=float32)