jax.numpy.argwhere

目录

jax.numpy.argwhere#

jax.numpy.argwhere(a, *, size=None, fill_value=None)[源代码][源代码]#

查找非零数组元素的索引

JAX 实现的 numpy.argwhere()

jnp.argwhere(x) 本质上等同于 jnp.column_stack(jnp.nonzero(x)),并特别处理零维(即标量)输入。

由于 argwhere 的输出大小取决于数据,该函数通常与即时编译(JIT)不兼容。JAX 版本添加了可选的 size 参数,该参数指定输出前导维度的大小 - 对于非静态操作数,必须静态指定 size 以使 jnp.argwhere 能够编译。有关 size 及其语义的完整讨论,请参见 jax.numpy.nonzero()

参数:
  • a (ArrayLike) – 要查找非零元素的数组

  • size (int | None) – 可选的整数,用于静态指定预期非零元素的数量。为了在 JAX 变换(如 jax.jit())中使用 argwhere,必须指定此项。更多信息请参见 jax.numpy.nonzero()

  • fill_value (ArrayLike | None) – 指定 size 时使用的填充值的可选数组。更多信息请参见 jax.numpy.nonzero()

返回:

一个形状为 [size, x.ndim] 的二维数组。如果 size 没有作为参数指定,它等于 x 中非零元素的数量。

返回类型:

Array

示例

二维数组:

>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

使用 jax.numpy.column_stack()jax.numpy.nonzero() 进行等效计算:

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)

零维(即标量)输入的特殊情况:

>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)