jax.numpy.setdiff1d

目录

jax.numpy.setdiff1d#

jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)[源代码][源代码]#

计算两个一维数组的集合差。

JAX 实现的 numpy.setdiff1d()

由于 setdiff1d 的输出大小取决于数据,该函数通常不与 jit() 和其他 JAX 变换兼容。JAX 版本添加了可选的 size 参数,必须在 jnp.setdiff1d 用于此类上下文中时静态指定该参数。

参数:
  • ar1 (ArrayLike) – 要进行差异化的第一个元素数组。

  • ar2 (ArrayLike) – 第二个要进行差异化的元素数组。

  • assume_unique (bool) – 如果为真,假设输入数组包含唯一值。这允许更高效的实现,但如果 assume_unique 为真且输入数组包含重复项,则行为未定义。默认值:False。

  • size (int | None) – 如果指定,则仅返回前 size 个排序后的元素。如果元素数量少于 size 所指示的数量,返回值将用 fill_value 填充。

  • fill_value (ArrayLike | None) – 当指定 size 并且元素数量少于指定数量时,用 fill_value 填充剩余的条目。默认为最小值。

返回:

ar1 中未包含在 ar2 中的元素。

返回类型:

an array containing the set difference of elements in the input array

参见

示例

计算两个数组的集合差集:

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setdiff1d(ar1, ar2)
Array([1, 2], dtype=int32)

因为输出形状是动态的,这将在 jit() 和其他变换下失败:

>>> jax.jit(jnp.setdiff1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

为了确保静态已知的输出形状,你可以传递一个静态的 size 参数:

>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size'])
>>> jit_setdiff1d(ar1, ar2, size=2)
Array([1, 2], dtype=int32)

如果 size 太小,差异将被截断:

>>> jit_setdiff1d(ar1, ar2, size=1)
Array([1], dtype=int32)

如果 size 过大,则输出会用 fill_value 填充:

>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0)
Array([1, 2, 0, 0], dtype=int32)