jax.numpy.nonzero#
- jax.numpy.nonzero(a, *, size=None, fill_value=None)[源代码][源代码]#
返回数组中非零元素的索引。
JAX 实现的
numpy.nonzero()
。由于
nonzero
的输出大小取决于数据,该函数与 JIT 和其他变换不兼容。JAX 版本添加了可选的size
参数,必须在 JAX 的变换中静态指定该参数,以便在 JAX 的变换中使用jnp.nonzero
。- 参数:
- 返回:
长度为
a.ndim
的 JAX 数组元组,包含每个非零值的索引。- 返回类型:
示例
一维数组返回一个长度为1的索引元组:
>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jnp.nonzero(x) (Array([1, 3, 5], dtype=int32),)
二维数组返回一个长度为2的索引元组:
>>> x = jnp.array([[0, 5, 0], ... [6, 0, 7]]) >>> jnp.nonzero(x) (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
无论哪种情况,生成的索引元组都可以直接用于提取非零值:
>>> indices = jnp.nonzero(x) >>> x[indices] Array([5, 6, 7], dtype=int32)
nonzero
的输出具有动态形状,因为返回的索引数量取决于输入数组的内容。因此,它与 JIT 和其他 JAX 变换不兼容:>>> x = jnp.array([0, 5, 0, 6, 0, 7]) >>> jax.jit(jnp.nonzero)(x) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
这可以通过传递一个静态的
size
参数来指定所需的输出形状来解决:>>> nonzero_jit = jax.jit(jnp.nonzero, static_argnames='size') >>> nonzero_jit(x, size=3) (Array([1, 3, 5], dtype=int32),)
如果
size
与实际大小不匹配,结果将会被截断或填充:>>> nonzero_jit(x, size=2) # size < 3: indices are truncated (Array([1, 3], dtype=int32),) >>> nonzero_jit(x, size=5) # size > 3: indices are padded with zeros. (Array([1, 3, 5, 0, 0], dtype=int32),)
你可以使用
fill_value
参数指定填充的自定义值:>>> nonzero_jit(x, size=5, fill_value=len(x)) (Array([1, 3, 5, 6, 6], dtype=int32),)