jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[源代码][源代码]#
返回满足条件的数组元素。
JAX 实现的
numpy.extract()
。- 参数:
- 返回:
提取条目的1D数组。如果指定了
size
,结果将具有形状(size,)
并使用fill_value
进行右填充。如果未指定size
,输出形状将取决于condition
中 True 条目的数量。- 返回类型:
备注
此函数不要求
condition
和arr
之间严格形状一致。如果condition.size > arr.size
,则condition
将被截断,如果arr.size > condition.size
,则arr
将被截断。参见
jax.numpy.compress()
:extract
的多维版本。示例
从一维数组中提取值:
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最简单的情况下,这等同于布尔索引:
>>> x[mask] Array([2, 4, 6], dtype=int32)
在使用 JAX 变换时,您可以传递
size
参数来指定输出的静态形状,同时可以选择传递一个默认值为零的fill_value
参数:>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
请注意,与布尔索引不同,
extract
不需要数组和条件之间的大小严格一致,并且实际上会将两者截断到最小大小:>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)