jax.numpy.expand_dims#
- jax.numpy.expand_dims(a, axis)[源代码][源代码]#
将长度为1的维度插入数组
JAX 实现的
numpy.expand_dims()
,通过jax.lax.expand_dims()
实现。- 参数:
- 返回:
增加了维度的
a
的副本。- 返回类型:
备注
与
numpy.expand_dims()
不同,jax.numpy.expand_dims()
将返回输入数组的副本而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会影响性能。参见
jax.numpy.squeeze()
: 此操作的逆操作,即移除长度为1的维度。jax.lax.expand_dims()
: 此功能的 XLA 版本。
示例
>>> x = jnp.array([1, 2, 3]) >>> x.shape (3,)
扩展前导维度:
>>> jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> _.shape (1, 3)
扩展尾随维度:
>>> jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> _.shape (3, 1)
扩展多个维度:
>>> jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32) >>> _.shape (1, 1, 3, 1)
维度也可以通过使用
None
进行索引来更简洁地扩展:>>> x[None] # equivalent to jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32)