jax.numpy.extract

目录

jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码][源代码]#

返回满足条件的数组元素。

JAX 实现的 numpy.extract()

参数:
  • condition (ArrayLike) – 条件数组。将被转换为布尔值并展平为1D。

  • arr (ArrayLike) – 要提取的值的数组。将被展平为1D。

  • size (int | None) – 输出可选的静态大小。必须指定,以便 extract 与 JAX 变换(如 jit()vmap())兼容。

  • fill_value (ArrayLike) – 如果指定了 size ,则用此值填充填充条目(默认值:0)。

返回:

提取条目的1D数组。如果指定了 size ,结果将具有形状 (size,) 并使用 fill_value 进行右填充。如果未指定 size ,输出形状将取决于 condition 中 True 条目的数量。

返回类型:

Array

备注

此函数不要求 conditionarr 之间严格形状一致。如果 condition.size > arr.size,则 condition 将被截断,如果 arr.size > condition.size,则 arr 将被截断。

参见

jax.numpy.compress(): extract 的多维版本。

示例

从一维数组中提取值:

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最简单的情况下,这等同于布尔索引:

>>> x[mask]
Array([2, 4, 6], dtype=int32)

在使用 JAX 变换时,您可以传递 size 参数来指定输出的静态形状,同时可以选择传递一个默认值为零的 fill_value 参数:

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

请注意,与布尔索引不同,extract 不需要数组和条件之间的大小严格一致,并且实际上会将两者截断到最小大小:

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)