jax.numpy.nonzero

目录

jax.numpy.nonzero#

jax.numpy.nonzero(a, *, size=None, fill_value=None)[源代码][源代码]#

返回数组中非零元素的索引。

JAX 实现的 numpy.nonzero()

由于 nonzero 的输出大小取决于数据,该函数与 JIT 和其他变换不兼容。JAX 版本添加了可选的 size 参数,必须在 JAX 的变换中静态指定该参数,以便在 JAX 的变换中使用 jnp.nonzero

参数:
  • a (ArrayLike) – N 维数组。

  • size (int | None) – 可选的静态整数,指定要返回的非零条目的数量。如果非零元素的数量超过指定的 size,则索引将在末尾被截断。如果非零元素的数量少于指定的数量,则索引将用 fill_value 填充,默认为零。

  • fill_value (None | ArrayLike | tuple[ArrayLike, ...]) – 当指定 size 时,可选的填充值。默认为 0。

返回:

长度为 a.ndim 的 JAX 数组元组,包含每个非零值的索引。

返回类型:

tuple[Array, …]

示例

一维数组返回一个长度为1的索引元组:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jnp.nonzero(x)
(Array([1, 3, 5], dtype=int32),)

二维数组返回一个长度为2的索引元组:

>>> x = jnp.array([[0, 5, 0],
...                [6, 0, 7]])
>>> jnp.nonzero(x)
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))

无论哪种情况,生成的索引元组都可以直接用于提取非零值:

>>> indices = jnp.nonzero(x)
>>> x[indices]
Array([5, 6, 7], dtype=int32)

nonzero 的输出具有动态形状,因为返回的索引数量取决于输入数组的内容。因此,它与 JIT 和其他 JAX 变换不兼容:

>>> x = jnp.array([0, 5, 0, 6, 0, 7])
>>> jax.jit(jnp.nonzero)(x)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.

这可以通过传递一个静态的 size 参数来指定所需的输出形状来解决:

>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size')
>>> nonzero_jit(x, size=3)
(Array([1, 3, 5], dtype=int32),)

如果 size 与实际大小不匹配,结果将会被截断或填充:

>>> nonzero_jit(x, size=2)  # size < 3: indices are truncated
(Array([1, 3], dtype=int32),)
>>> nonzero_jit(x, size=5)  # size > 3: indices are padded with zeros.
(Array([1, 3, 5, 0, 0], dtype=int32),)

你可以使用 fill_value 参数指定填充的自定义值:

>>> nonzero_jit(x, size=5, fill_value=len(x))
(Array([1, 3, 5, 6, 6], dtype=int32),)