jax.numpy.ravel_多索引

jax.numpy.ravel_多索引#

jax.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C')[源代码][源代码]#

将多维索引转换为扁平索引。

JAX 实现的 numpy.ravel_multi_index()

参数:
  • multi_index (Sequence[ArrayLike]) – 包含每个维度索引的整数数组的序列。

  • dims (Sequence[int]) – 整数大小的序列;必须满足 len(dims) == len(multi_index)

  • mode (str) – 如何处理越界索引。选项有 - "raise" (默认): 抛出 ValueError。此模式与 jit() 或其他 JAX 转换不兼容。 - "clip": 将越界索引裁剪到有效范围内。 - "wrap": 将越界索引环绕到有效范围内。

  • order (str) – "C" (默认) 或 "F",指定是采用 C 风格的行优先顺序还是 Fortran 风格的列优先顺序。

返回:

展平索引的数组

返回类型:

Array

参见

jax.numpy.unravel_index(): 此函数的逆函数。

示例

定义一个二维数组和一个偶数值索引序列:

>>> x = jnp.array([[2., 3., 4.],
...                [5., 6., 7.]])
>>> indices = jnp.where(x % 2 == 0)
>>> indices
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
>>> x[indices]
Array([2., 4., 6.], dtype=float32)

计算展平后的索引:

>>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
>>> indices_flat
Array([0, 2, 4], dtype=int32)

这些扁平化的索引可以用来从扁平化的 x 数组中提取相同的值:

>>> x_flat = x.ravel()
>>> x_flat
Array([2., 3., 4., 5., 6., 7.], dtype=float32)
>>> x_flat[indices_flat]
Array([2., 4., 6.], dtype=float32)

原始索引可以通过 unravel_index() 恢复:

>>> jnp.unravel_index(indices_flat, x.shape)
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))