jax.numpy.unravel_index#
- jax.numpy.unravel_index(indices, shape)[源代码][源代码]#
将平面索引转换为多维索引。
JAX 实现的
numpy.unravel_index()
。JAX 版本在处理越界索引时有所不同:与 NumPy 不同,JAX 支持负索引,并且越界索引会被裁剪到最近的有效值。参见
jax.numpy.ravel_multi_index()
:此函数的逆函数。示例
从一个一维数组值和索引开始:
>>> x = jnp.array([2., 3., 4., 5., 6., 7.]) >>> indices = jnp.array([1, 3, 5]) >>> print(x[indices]) [3. 5. 7.]
现在如果
x
被重塑,unravel_indices
可以用来将扁平索引转换为访问相同条目的索引元组:>>> shape = (2, 3) >>> x_2D = x.reshape(shape) >>> indices_2D = jnp.unravel_index(indices, shape) >>> indices_2D (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32)) >>> print(x_2D[indices_2D]) [3. 5. 7.]
逆函数
ravel_multi_index
可以用来获取原始索引:>>> jnp.ravel_multi_index(indices_2D, shape) Array([1, 3, 5], dtype=int32)