jax.scipy.linalg.lu#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array] [源代码][源代码]#
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
- jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]
计算LU分解
JAX 实现的
scipy.linalg.lu()
。矩阵 A 的 LU 分解是:
\[A = P L U\]其中 P 是一个置换矩阵,L 是下三角矩阵,U 是上三角矩阵。
- 参数:
a – 形状为
(..., M, N)
的数组进行分解。permute_l – 如果为真,则置换
L
并返回(P @ L, U)
(默认:False)overwrite_a – 不被 JAX 使用
check_finite – 不被 JAX 使用
- 返回:
P
是一个形状为(..., M, M)
的置换矩阵L
是一个形状为(... M, K)
的下三角矩阵U
是一个形状为(..., K, N)
的上三角矩阵
使用
K = min(M, N)
- 返回类型:
如果
permute_l
为 True,则为数组的元组(P @ L, U)
,否则为(P, L, U)
参见
jax.numpy.linalg.lu()
: 用于LU分解的NumPy风格API。jax.lax.linalg.lu()
: 用于LU分解的XLA风格API。jax.scipy.linalg.lu_solve()
: 基于LU分解的线性求解器。
示例
一个 3x3 矩阵的 LU 分解:
>>> a = jnp.array([[1., 2., 3.], ... [5., 4., 2.], ... [3., 2., 1.]]) >>> P, L, U = jax.scipy.linalg.lu(a)
P
是一个置换矩阵:即每行和每列都有一个单独的1
:>>> P Array([[0., 1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=float32)
L
和U
是下三角矩阵和上三角矩阵:>>> with jnp.printoptions(precision=3): ... print(L) ... print(U) [[ 1. 0. 0. ] [ 0.2 1. 0. ] [ 0.6 -0.333 1. ]] [[5. 4. 2. ] [0. 1.2 2.6 ] [0. 0. 0.667]]
原始矩阵可以通过将这三个矩阵相乘来重建:
>>> a_reconstructed = P @ L @ U >>> jnp.allclose(a, a_reconstructed) Array(True, dtype=bool)