jax.numpy.take_along_axis#
- jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)[源代码][源代码]#
从数组中提取元素。
JAX 实现的
numpy.take_along_axis()
,基于jax.lax.gather()
实现。JAX 的行为在越界索引的情况下与 NumPy 不同;请参见下面的mode
参数。- 参数:
a – 从中获取值的数组。
indices (ArrayLike) – 整数索引数组。如果
axis
是None
,则必须是一维的。如果axis
不是 None,则必须满足a.ndim == indices.ndim
,并且a
必须在与axis
不同的维度上与indices
广播兼容。axis (int | None) – 要沿其取值的轴。如果未指定,将在应用索引之前将数组展平。
mode (str | lax.GatherScatterMode | None) – 越界索引模式,可以是
"fill"
或"clip"
。默认的mode="fill"
对越界索引返回无效值(例如 NaN)。有关mode
选项的更多讨论,请参见jax.numpy.ndarray.at
。arr (ArrayLike)
fill_value (StaticScalar | None)
- 返回:
从
a
中提取的值数组。- 返回类型:
参见
jax.numpy.ndarray.at
: 通过索引语法获取值。jax.numpy.take()
: 在每个轴切片上取相同的索引。
示例
>>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([[0, 2], ... [1, 0]]) >>> jnp.take_along_axis(x, indices, axis=1) Array([[1., 3.], [5., 4.]], dtype=float32) >>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax Array([[1., 3.], [5., 4.]], dtype=float32)
越界索引填充为无效值。对于浮点输入,这是 NaN:
>>> indices = jnp.array([[1, 0, 2]]) >>> jnp.take_along_axis(x, indices, axis=0) Array([[ 4., 2., nan]], dtype=float32) >>> x.at[indices, jnp.arange(3)].get( ... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax Array([[ 4., 2., nan]], dtype=float32)
take_along_axis
对于从多维 argsorts 和 arg 减少中提取值非常有帮助。例如,我们在这里沿着一个轴计算argsort()
索引,并使用take_along_axis
来构建排序后的数组:>>> x = jnp.array([[5, 3, 4], ... [2, 7, 6]]) >>> indices = jnp.argsort(x, axis=1) >>> indices Array([[1, 2, 0], [0, 2, 1]], dtype=int32) >>> jnp.take_along_axis(x, indices, axis=1) Array([[3, 4, 5], [2, 6, 7]], dtype=int32)
同样地,我们可以使用
argmin()
并设置keepdims=True
,然后使用take_along_axis
来提取最小值:>>> idx = jnp.argmin(x, axis=1, keepdims=True) >>> idx Array([[1], [0]], dtype=int32) >>> jnp.take_along_axis(x, idx, axis=1) Array([[3], [2]], dtype=int32)