jax.numpy.ravel_多索引#
- jax.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C')[源代码][源代码]#
将多维索引转换为扁平索引。
JAX 实现的
numpy.ravel_multi_index()
- 参数:
multi_index (Sequence[ArrayLike]) – 包含每个维度索引的整数数组的序列。
dims (Sequence[int]) – 整数大小的序列;必须满足
len(dims) == len(multi_index)
mode (str) – 如何处理越界索引。选项有 -
"raise"
(默认): 抛出 ValueError。此模式与jit()
或其他 JAX 转换不兼容。 -"clip"
: 将越界索引裁剪到有效范围内。 -"wrap"
: 将越界索引环绕到有效范围内。order (str) –
"C"
(默认) 或"F"
,指定是采用 C 风格的行优先顺序还是 Fortran 风格的列优先顺序。
- 返回:
展平索引的数组
- 返回类型:
参见
jax.numpy.unravel_index()
: 此函数的逆函数。示例
定义一个二维数组和一个偶数值索引序列:
>>> x = jnp.array([[2., 3., 4.], ... [5., 6., 7.]]) >>> indices = jnp.where(x % 2 == 0) >>> indices (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32)) >>> x[indices] Array([2., 4., 6.], dtype=float32)
计算展平后的索引:
>>> indices_flat = jnp.ravel_multi_index(indices, x.shape) >>> indices_flat Array([0, 2, 4], dtype=int32)
这些扁平化的索引可以用来从扁平化的
x
数组中提取相同的值:>>> x_flat = x.ravel() >>> x_flat Array([2., 3., 4., 5., 6., 7.], dtype=float32) >>> x_flat[indices_flat] Array([2., 4., 6.], dtype=float32)
原始索引可以通过
unravel_index()
恢复:>>> jnp.unravel_index(indices_flat, x.shape) (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))