jax.numpy.linalg.inv#
- jax.numpy.linalg.inv(a)[源代码][源代码]#
返回一个方阵的逆矩阵
JAX 实现的
numpy.linalg.inv()
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组,指定要被求逆的方阵。- 返回:
形状为
(..., N, N)
的数组,包含输入的逆矩阵。- 返回类型:
备注
在大多数情况下,显式计算矩阵的逆是不明智的。例如,要计算
x = inv(A) @ b
,使用直接求解,如jax.scipy.linalg.solve()
,会更加高效且数值精确。参见
jax.scipy.linalg.inv()
: 用于矩阵逆的 SciPy 风格 APIjax.numpy.linalg.solve()
: 直接线性求解器
示例
计算一个3x3矩阵的逆
>>> a = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> a_inv = jnp.linalg.inv(a) >>> a_inv Array([[ 0. , -0.25 , 0.5 ], [-0.25 , 0.5 , -0.25000003], [ 0.5 , -0.25 , 0. ]], dtype=float32)
检查与逆矩阵相乘是否得到单位矩阵:
>>> jnp.allclose(a @ a_inv, jnp.eye(3), atol=1E-5) Array(True, dtype=bool)
将逆矩阵乘以向量
b
,以找到a @ x = b
的解。>>> b = jnp.array([1., 4., 2.]) >>> a_inv @ b Array([ 0. , 1.25, -0.5 ], dtype=float32)
然而,需要注意的是,在这种情况下显式计算逆矩阵可能会导致性能下降和精度损失,因为问题规模增大。相反,你应该使用像
jax.numpy.linalg.solve()
这样的直接求解器:>>> jnp.linalg.solve(a, b) Array([ 0. , 1.25, -0.5 ], dtype=float32)