jax.numpy.argwhere#
- jax.numpy.argwhere(a, *, size=None, fill_value=None)[源代码][源代码]#
查找非零数组元素的索引
JAX 实现的
numpy.argwhere()
。jnp.argwhere(x)
本质上等同于jnp.column_stack(jnp.nonzero(x))
,并特别处理零维(即标量)输入。由于
argwhere
的输出大小取决于数据,该函数通常与即时编译(JIT)不兼容。JAX 版本添加了可选的size
参数,该参数指定输出前导维度的大小 - 对于非静态操作数,必须静态指定size
以使jnp.argwhere
能够编译。有关size
及其语义的完整讨论,请参见jax.numpy.nonzero()
。- 参数:
a (ArrayLike) – 要查找非零元素的数组
size (int | None) – 可选的整数,用于静态指定预期非零元素的数量。为了在 JAX 变换(如
jax.jit()
)中使用argwhere
,必须指定此项。更多信息请参见jax.numpy.nonzero()
。fill_value (ArrayLike | None) – 指定
size
时使用的填充值的可选数组。更多信息请参见jax.numpy.nonzero()
。
- 返回:
一个形状为
[size, x.ndim]
的二维数组。如果size
没有作为参数指定,它等于x
中非零元素的数量。- 返回类型:
示例
二维数组:
>>> x = jnp.array([[1, 0, 2], ... [0, 3, 0]]) >>> jnp.argwhere(x) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
使用
jax.numpy.column_stack()
和jax.numpy.nonzero()
进行等效计算:>>> jnp.column_stack(jnp.nonzero(x)) Array([[0, 0], [0, 2], [1, 1]], dtype=int32)
零维(即标量)输入的特殊情况:
>>> jnp.argwhere(1) Array([], shape=(1, 0), dtype=int32) >>> jnp.argwhere(0) Array([], shape=(0, 0), dtype=int32)