jax.numpy.diag_indices#
- jax.numpy.diag_indices(n, ndim=2)[源代码][源代码]#
返回用于访问多维数组主对角线的索引。
JAX 实现的
numpy.diag_indices()
。- 参数:
- 返回:
一个数组的元组,每个数组的长度为 n,包含访问主对角线的索引。
- 返回类型:
示例
>>> jnp.diag_indices(3) (Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32)) >>> jnp.diag_indices(4, ndim=3) (Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))