jax.numpy.intersect1d

目录

jax.numpy.intersect1d#

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[源代码][源代码]#

计算两个一维数组的集合交集。

JAX 实现的 numpy.intersect1d()

由于 intersect1d 的输出大小取决于数据,该函数通常不兼容 jit() 和其他 JAX 变换。JAX 版本添加了可选的 size 参数,必须在 jnp.intersect1d 用于此类上下文时静态指定该参数。

参数:
  • ar1 (ArrayLike) – 第一个要相交的值数组。

  • ar2 (ArrayLike) – 第二个要相交的值数组。

  • assume_unique (bool) – 如果为真,假设输入数组包含唯一值。这允许更高效的实现,但如果 assume_unique 为真且输入数组包含重复项,则行为未定义。默认值:False。

  • return_indices (bool) – 如果为 True,返回索引数组,指定交集值在输入数组中首次出现的位置。

  • size (int | None) – 如果指定,则仅返回前 size 个排序后的元素。如果元素数量少于 size 指示的数量,返回值将用 fill_value 填充,返回的索引将用超出范围的索引填充。

  • fill_value (ArrayLike | None) – 当指定了 size 并且元素数量少于指定数量时,用 fill_value 填充剩余的条目。默认为交集中的最小值。

返回:

一个数组 intersection,或者如果 return_indices=True,则是一个数组元组 (intersection, ar1_indices, ar2_indices)。返回的值是 - intersection: 一个包含在 ar1ar2 中都出现的每个值的 1D 数组。 - ar1_indices: (如果 return_indices=True 则返回) 一个形状为 intersection.shape 的数组,包含 intersection 中值在展平的 ar1 中的索引。对于 1D 输入,intersection 等同于 ar1[ar1_indices]。 - ar2_indices: (如果 return_indices=True 则返回) 一个形状为 intersection.shape 的数组,包含 intersection 中值在展平的 ar2 中的索引。对于 1D 输入,intersection 等同于 ar2[ar2_indices]

返回类型:

Array | tuple[Array, Array, Array]

参见

示例

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

计算与索引的交集:

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices 给出了 ar1 中相交值的索引:

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices 给出了在 ar2 中相交值的索引:

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)