jax.numpy.intersect1d#
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[源代码][源代码]#
计算两个一维数组的集合交集。
JAX 实现的
numpy.intersect1d()
。由于
intersect1d
的输出大小取决于数据,该函数通常不兼容jit()
和其他 JAX 变换。JAX 版本添加了可选的size
参数,必须在jnp.intersect1d
用于此类上下文时静态指定该参数。- 参数:
ar1 (ArrayLike) – 第一个要相交的值数组。
ar2 (ArrayLike) – 第二个要相交的值数组。
assume_unique (bool) – 如果为真,假设输入数组包含唯一值。这允许更高效的实现,但如果
assume_unique
为真且输入数组包含重复项,则行为未定义。默认值:False。return_indices (bool) – 如果为 True,返回索引数组,指定交集值在输入数组中首次出现的位置。
size (int | None) – 如果指定,则仅返回前
size
个排序后的元素。如果元素数量少于size
指示的数量,返回值将用fill_value
填充,返回的索引将用超出范围的索引填充。fill_value (ArrayLike | None) – 当指定了
size
并且元素数量少于指定数量时,用fill_value
填充剩余的条目。默认为交集中的最小值。
- 返回:
一个数组
intersection
,或者如果return_indices=True
,则是一个数组元组(intersection, ar1_indices, ar2_indices)
。返回的值是 -intersection
: 一个包含在ar1
和ar2
中都出现的每个值的 1D 数组。 -ar1_indices
: (如果 return_indices=True 则返回) 一个形状为intersection.shape
的数组,包含intersection
中值在展平的ar1
中的索引。对于 1D 输入,intersection
等同于ar1[ar1_indices]
。 -ar2_indices
: (如果 return_indices=True 则返回) 一个形状为intersection.shape
的数组,包含intersection
中值在展平的ar2
中的索引。对于 1D 输入,intersection
等同于ar2[ar2_indices]
。- 返回类型:
参见
jax.numpy.union1d()
: 两个一维数组的集合并集。jax.numpy.setxor1d()
: 两个一维数组的集合异或。jax.numpy.setdiff1d()
: 两个一维数组的集合差。
示例
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
计算与索引的交集:
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indices
给出了ar1
中相交值的索引:>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indices
给出了在ar2
中相交值的索引:>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)