jax.numpy.diag_indices

jax.numpy.diag_indices#

jax.numpy.diag_indices(n, ndim=2)[源代码][源代码]#

返回用于访问多维数组主对角线的索引。

JAX 实现的 numpy.diag_indices()

参数:
  • n (int) – int. 方阵每个维度的大小。

  • ndim (int) – 可选, int, 默认=2。数组的维度数量。

返回:

一个数组的元组,每个数组的长度为 n,包含访问主对角线的索引。

返回类型:

tuple[Array, …]

示例

>>> jnp.diag_indices(3)
(Array([0, 1, 2], dtype=int32), Array([0, 1, 2], dtype=int32))
>>> jnp.diag_indices(4, ndim=3)
(Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32),
Array([0, 1, 2, 3], dtype=int32))