jax.scipy.linalg.block_diag#
- jax.scipy.linalg.block_diag(*arrs)[源代码][源代码]#
从输入数组创建一个块对角矩阵。
JAX 实现的
scipy.linalg.block_diag()
。- 参数:
*arrs (ArrayLike) – 最多两维的数组
- 返回:
通过将输入数组沿对角线放置构建的二维块对角阵列。
- 返回类型:
示例
>>> A = jnp.ones((1, 1)) >>> B = jnp.ones((2, 2)) >>> C = jnp.ones((3, 3)) >>> jax.scipy.linalg.block_diag(A, B, C) Array([[1., 0., 0., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 1., 1., 0., 0., 0.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.]], dtype=float32)