jax.scipy.linalg.lu

目录

jax.scipy.linalg.lu#

jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[False] = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array, Array][源代码][源代码]#
jax.scipy.linalg.lu(a: ArrayLike, permute_l: Literal[True], overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array]
jax.scipy.linalg.lu(a: ArrayLike, permute_l: bool = False, overwrite_a: bool = False, check_finite: bool = True) tuple[Array, Array] | tuple[Array, Array, Array]

计算LU分解

JAX 实现的 scipy.linalg.lu()

矩阵 A 的 LU 分解是:

\[A = P L U\]

其中 P 是一个置换矩阵,L 是下三角矩阵,U 是上三角矩阵。

参数:
  • a – 形状为 (..., M, N) 的数组进行分解。

  • permute_l – 如果为真,则置换 L 并返回 (P @ L, U) (默认:False)

  • overwrite_a – 不被 JAX 使用

  • check_finite – 不被 JAX 使用

返回:

  • P 是一个形状为 (..., M, M) 的置换矩阵

  • L 是一个形状为 (... M, K) 的下三角矩阵

  • U 是一个形状为 (..., K, N) 的上三角矩阵

使用 K = min(M, N)

返回类型:

如果 permute_l 为 True,则为数组的元组 (P @ L, U),否则为 (P, L, U)

参见

示例

一个 3x3 矩阵的 LU 分解:

>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> P, L, U = jax.scipy.linalg.lu(a)

P 是一个置换矩阵:即每行和每列都有一个单独的 1

>>> P
Array([[0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.]], dtype=float32)

LU 是下三角矩阵和上三角矩阵:

>>> with jnp.printoptions(precision=3):
...   print(L)
...   print(U)
[[ 1.     0.     0.   ]
 [ 0.2    1.     0.   ]
 [ 0.6   -0.333  1.   ]]
[[5.    4.    2.   ]
 [0.    1.2   2.6  ]
 [0.    0.    0.667]]

原始矩阵可以通过将这三个矩阵相乘来重建:

>>> a_reconstructed = P @ L @ U
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)