jax.numpy.where#
- jax.numpy.where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) tuple[Array, ...] [源代码][源代码]#
- jax.numpy.where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array
- jax.numpy.where(condition: ArrayLike, x: ArrayLike | None = None, y: ArrayLike | None = None, /, *, size: int | None = None, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None) Array | tuple[Array, ...]
根据条件从两个数组中选择元素。
JAX 实现的
numpy.where()
。备注
当仅提供
condition
时,jnp.where(condition)
等同于jnp.nonzero(condition)
。对于这种情况,请参阅jax.numpy.nonzero()
的文档。下面的文档字符串专注于x
和y
被指定的情况。三参数版本的
jnp.where
会降低到jax.lax.select()
。- 参数:
condition – 布尔数组。当指定
x
和y
时,必须与它们广播兼容。x – arraylike。应与
condition
和y
广播兼容,并与y
类型转换兼容。y – arraylike。应与
condition
和x
广播兼容,并与x
类型转换兼容。size – 整数,仅在
x
和y
为None
时引用。详情请参见jax.numpy.nonzero()
。fill_value – 仅在
x
和y
为None
时引用。详情请参见jax.numpy.nonzero()
。
- 返回:
一个 dtype 为
jnp.result_type(x, y)
的数组,其中condition
为 True 时从x
中取值,condition
为 False 时从y
中取值。如果x
和y
为None
,函数的行为会有所不同;请参阅jax.numpy.nonzero()
以了解返回类型的描述。
备注
在使用
jax.numpy.where()
时,如果x
或y
输入可能为 NaN,则需要特别注意。具体来说,当使用jax.grad`(反向模式微分)计算梯度时,无论 ``condition`()
的值如何,x
或y
中的 NaN 都会传播到梯度中。有关此行为的更多信息和解决方法,请参阅 JAX FAQ。示例
当
x
和y
未提供时,where
的行为等同于jax.numpy.nonzero()
:>>> x = jnp.arange(10) >>> jnp.where(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),) >>> jnp.nonzero(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),)
当提供
x
和y
时,where
根据指定的条件在它们之间进行选择:>>> jnp.where(x > 4, x, 0) Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)