jax.numpy.linalg.solve#
- jax.numpy.linalg.solve(a, b)[源代码][源代码]#
求解线性方程组
JAX 实现的
numpy.linalg.solve()
。这解决了给定
a
和b
的线性方程组a @ x = b
以求解x
。- 参数:
a (ArrayLike) – 形状为
(..., N, N)
的数组。b (ArrayLike) – 形状为
(N,)
的数组(用于一维右侧)或(..., N, M)
的数组(用于批量二维右侧)。
- 返回:
包含线性求解结果的数组。如果
b
的形状是(N,)
,则结果的形状是(..., N)
,否则结果的形状是(..., N, M)
。- 返回类型:
参见
jax.scipy.linalg.solve()
: 用于求解线性系统的 SciPy 风格 API。jax.lax.custom_linear_solve()
: 无矩阵线性求解器。
示例
一个简单的 3x3 线性系统:
>>> A = jnp.array([[1., 2., 3.], ... [2., 4., 2.], ... [3., 2., 1.]]) >>> b = jnp.array([14., 16., 10.]) >>> x = jnp.linalg.solve(A, b) >>> x Array([1., 2., 3.], dtype=float32)
确认结果解决了系统:
>>> jnp.allclose(A @ x, b) Array(True, dtype=bool)