jax.numpy.linalg.lstsq

目录

jax.numpy.linalg.lstsq#

jax.numpy.linalg.lstsq(a, b, rcond=None, *, numpy_resid=False)[源代码][源代码]#

返回线性方程的最小二乘解。

JAX 实现的 numpy.linalg.lstsq()

参数:
  • a (ArrayLike) – 形状为 (M, N) 的数组,表示系数矩阵。

  • b (ArrayLike) – 形状为 (M,)(M, K) 的数组,表示右侧。

  • rcond (float | None) – 小奇异值的截断比率。小于 rcond * 最大奇异值 的奇异值被视为零。如果为 None(默认),将使用最佳值以减少浮点误差。

  • numpy_resid (bool) – 如果为 True,将以与 NumPy 的 linalg.lstsq 相同的方式计算并返回残差。如果你想精确复制 NumPy 的行为,这是必要的。如果为 False(默认),则使用一种更高效的方法来计算残差。

返回:

数组的元组 (x, resid, rank, s) 其中 - x 是一个形状为 (N,)(N, K) 的数组,包含最小二乘解。 - resid 是形状为 ()(K,) 的平方残差和。 - rank 是矩阵 a 的秩。 - s 是矩阵 a 的奇异值。

返回类型:

tuple[Array, Array, Array, Array]

示例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([5, 6])
>>> x, _, _, _ = jnp.linalg.lstsq(a, b)
>>> with jnp.printoptions(precision=3):
...   print(x)
[-4.   4.5]