jax.scipy.linalg.svd

目录

jax.scipy.linalg.svd#

jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: Literal[True] = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') tuple[Array, Array, Array][源代码][源代码]#
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, *, compute_uv: Literal[False], overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array
jax.scipy.linalg.svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True, overwrite_a: bool = False, check_finite: bool = True, lapack_driver: str = 'gesdd') Array | tuple[Array, Array, Array]

计算奇异值分解。

JAX 实现的 scipy.linalg.svd()

矩阵 A 的奇异值分解(SVD)由以下公式给出

\[A = U\Sigma V^H\]
  • \(U\) 包含左奇异向量,并且满足 \(U^HU=I\)

  • \(V\) 包含右奇异向量,并且满足 \(V^HV=I\)

  • \(\Sigma\) 是一个奇异值的对角矩阵。

参数:
  • a – 输入数组,形状为 (..., N, M)

  • full_matrices – 如果为 True(默认),则计算完整的矩阵;即 uvh 的形状为 (..., N, N)(..., M, M)。如果为 False,则形状为 (..., N, K)(..., K, M),其中 K = min(N, M)

  • compute_uv – 如果为 True(默认),返回完整的 SVD (u, s, vh)。如果为 False,则仅返回奇异值 s

  • overwrite_a – JAX 未使用

  • check_finite – JAX 未使用

  • lapack_driver – JAX 未使用

返回:

如果 compute_uv 为 True,则为数组的元组 (u, s, vh),否则为数组 s。 - u:形状为 (..., N, N) 的左奇异向量,如果 full_matrices 为 True,否则为 (..., N, K)。 - s:形状为 (..., K) 的奇异值。 - vh:形状为 (..., M, M) 的共轭转置右奇异向量,如果 full_matrices 为 True,否则为 (..., K, M)。 其中 K = min(N, M)

参见

示例

考虑一个小实值数组的奇异值分解:

>>> x = jnp.array([[1., 2., 3.],
...                [6., 5., 4.]])
>>> u, s, vt = jax.scipy.linalg.svd(x, full_matrices=False)
>>> s  
Array([9.361919 , 1.8315067], dtype=float32)

奇异向量位于 uv = vt.T 的列中。这些向量是正交的,这可以通过与单位矩阵比较矩阵乘积来证明:

>>> jnp.allclose(u.T @ u, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)
>>> v = vt.T
>>> jnp.allclose(v.T @ v, jnp.eye(2), atol=1E-5)
Array(True, dtype=bool)

给定SVD,x 可以通过矩阵乘法进行重构:

>>> x_reconstructed = u @ jnp.diag(s) @ vt
>>> jnp.allclose(x_reconstructed, x)
Array(True, dtype=bool)