jax.numpy.take_along_axis

jax.numpy.take_along_axis#

jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[源代码][源代码]#

从数组中提取元素。

JAX 实现的 numpy.take_along_axis(),基于 jax.lax.gather() 实现。JAX 的行为在越界索引的情况下与 NumPy 不同;请参见下面的 mode 参数。

参数:
  • a – 从中获取值的数组。

  • indices (ArrayLike) – 整数索引数组。如果 axisNone,则必须是一维的。如果 axis 不是 None,则必须满足 a.ndim == indices.ndim,并且 a 必须在与 axis 不同的维度上与 indices 广播兼容。

  • axis (int | None) – 要沿其取值的轴。如果未指定,将在应用索引之前将数组展平。

  • mode (str | lax.GatherScatterMode | None) – 越界索引模式,可以是 "fill""clip"。默认的 mode="fill" 对越界索引返回无效值(例如 NaN)。有关 mode 选项的更多讨论,请参见 jax.numpy.ndarray.at

  • arr (ArrayLike)

  • fill_value (StaticScalar | None)

返回:

a 中提取的值数组。

返回类型:

Array

参见

示例

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)

越界索引填充为无效值。对于浮点输入,这是 NaN

>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)

take_along_axis 对于从多维 argsorts 和 arg 减少中提取值非常有帮助。例如,我们在这里沿着一个轴计算 argsort() 索引,并使用 take_along_axis 来构建排序后的数组:

>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)

同样地,我们可以使用 argmin() 并设置 keepdims=True,然后使用 take_along_axis 来提取最小值:

>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)