jax.scipy.linalg.solve#
- jax.scipy.linalg.solve(a, b, lower=False, overwrite_a=False, overwrite_b=False, debug=False, check_finite=True, assume_a='gen')[源代码][源代码]#
求解线性方程组
JAX 实现的
scipy.linalg.solve()
。这解决了给定
a
和b
的线性方程组a @ x = b
以求解x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (ArrayLike) – 形状为
(..., N)
或(..., N, M)
的数组lower (bool) – 仅在
assume_a != 'gen'
时引用。如果为 True,则仅使用输入的下三角部分;如果为 False(默认),则仅使用上三角部分。assume_a (str) – 指定
a
的属性可以假设什么。选项包括: -"gen"
:通用矩阵(默认) -"sym"
:对称矩阵 -"her"
:厄米矩阵 -"pos"
:正定矩阵overwrite_a (bool) – JAX 未使用
overwrite_b (bool) – JAX 未使用
debug (bool) – JAX 未使用
check_finite (bool) – JAX 未使用
- 返回:
一个与
b
形状相同的数组,包含线性系统的解。- 返回类型:
参见
jax.scipy.linalg.lu_solve()
: 通过LU分解求解。jax.scipy.linalg.cho_solve()
: 通过 Cholesky 分解求解。jax.scipy.linalg.solve_triangular()
: 解一个三角系统。jax.numpy.linalg.solve()
: 用于求解线性系统的 NumPy 风格 API。jax.lax.custom_linear_solve()
: 无矩阵线性求解器。
示例
一个简单的 3x3 线性系统:
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jax.scipy.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果解决了系统:
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)