jax.numpy.unravel_index

jax.numpy.unravel_index#

jax.numpy.unravel_index(indices, shape)[源代码][源代码]#

将平面索引转换为多维索引。

JAX 实现的 numpy.unravel_index()。JAX 版本在处理越界索引时有所不同:与 NumPy 不同,JAX 支持负索引,并且越界索引会被裁剪到最近的有效值。

参数:
  • indices (ArrayLike) – 扁平索引的整数数组

  • shape (Shape) – 多维数组的形状以进行索引

返回:

解包索引的元组

返回类型:

tuple[Array, …]

参见

jax.numpy.ravel_multi_index():此函数的逆函数。

示例

从一个一维数组值和索引开始:

>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

现在如果 x 被重塑,unravel_indices 可以用来将扁平索引转换为访问相同条目的索引元组:

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

逆函数 ravel_multi_index 可以用来获取原始索引:

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)