jax.numpy.transpose#
- jax.numpy.transpose(a, axes=None)[源代码][源代码]#
返回一个 N 维数组的转置版本。
JAX 实现
numpy.transpose()
,基于jax.lax.transpose()
实现。- 参数:
a (ArrayLike) – 输入数组
axes (Sequence[int] | None) – 可选地使用长度为 a.ndim 的整数序列
i
来指定排列,满足0 <= i < a.ndim
。默认为range(a.ndim)[::-1]
,即反转所有轴的顺序。
- 返回:
数组的转置副本。
- 返回类型:
参见
jax.Array.transpose()
: 通过Array
方法实现的等效函数。jax.Array.T
: 通过Array
属性实现的等效功能。jax.numpy.matrix_transpose()
:转置数组的最后两个轴。这适用于处理批量二维矩阵。jax.numpy.swapaxes()
: 交换数组中的任意两个轴。jax.numpy.moveaxis()
: 将轴移动到数组中的另一个位置。
备注
与
numpy.transpose()
不同,jax.numpy.transpose()
将返回输入数组的一个副本而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会影响性能。示例
对于一维数组,转置是恒等变换:
>>> x = jnp.array([1, 2, 3, 4]) >>> jnp.transpose(x) Array([1, 2, 3, 4], dtype=int32)
对于一个二维数组,转置是一个矩阵转置:
>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.transpose(x) Array([[1, 3], [2, 4]], dtype=int32)
对于一个N维数组,转置会反转轴的顺序:
>>> x = jnp.zeros(shape=(3, 4, 5)) >>> jnp.transpose(x).shape (5, 4, 3)
可以通过指定
axes
参数来改变这种默认行为:>>> jnp.transpose(x, (0, 2, 1)).shape (3, 5, 4)
由于交换最后两个轴是一个常见的操作,可以通过其自身的API来完成,即
jax.numpy.matrix_transpose()
:>>> jnp.matrix_transpose(x).shape (3, 5, 4)
为了方便,转置也可以使用
jax.Array.transpose()
方法或jax.Array.T
属性来执行:>>> x = jnp.array([[1, 2], ... [3, 4]]) >>> x.transpose() Array([[1, 3], [2, 4]], dtype=int32) >>> x.T Array([[1, 3], [2, 4]], dtype=int32)