jax.numpy.squeeze#
- jax.numpy.squeeze(a, axis=None)[源代码][源代码]#
从数组中移除一个或多个长度为1的轴
JAX 实现的
numpy.sqeeze()
,通过jax.lax.squeeze()
实现。- 参数:
- 返回:
移除长度为1的轴后的
a
的副本。- 返回类型:
备注
与
numpy.squeeze()
不同,jax.numpy.squeeze()
将返回输入数组的副本而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会影响性能。参见
jax.numpy.expand_dims()
:squeeze
的逆操作:添加长度为1的维度。jax.Array.squeeze()
: 通过数组方法实现的等效功能。jax.lax.squeeze()
: 等效的 XLA API。jax.numpy.ravel()
: 将数组展平为1D形状。jax.numpy.reshape()
: 通用数组重塑。
示例
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
压缩所有长度为1的维度:
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
显式指定轴时的等效写法:
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
尝试压缩非单位轴会导致错误:
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
为了方便,此功能也可以通过
jax.Array.squeeze()
方法来使用:>>> x.squeeze() Array([0, 1, 2], dtype=int32)