jax.scipy.linalg.block_diag

目录

jax.scipy.linalg.block_diag#

jax.scipy.linalg.block_diag(*arrs)[源代码][源代码]#

从输入数组创建一个块对角矩阵。

JAX 实现的 scipy.linalg.block_diag()

参数:

*arrs (ArrayLike) – 最多两维的数组

返回:

通过将输入数组沿对角线放置构建的二维块对角阵列。

返回类型:

Array

示例

>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)