jax.numpy.transpose

目录

jax.numpy.transpose#

jax.numpy.transpose(a, axes=None)[源代码][源代码]#

返回一个 N 维数组的转置版本。

JAX 实现 numpy.transpose(),基于 jax.lax.transpose() 实现。

参数:
  • a (ArrayLike) – 输入数组

  • axes (Sequence[int] | None) – 可选地使用长度为 a.ndim 的整数序列 i 来指定排列,满足 0 <= i < a.ndim。默认为 range(a.ndim)[::-1],即反转所有轴的顺序。

返回:

数组的转置副本。

返回类型:

Array

参见

备注

numpy.transpose() 不同,jax.numpy.transpose() 将返回输入数组的一个副本而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会影响性能。

示例

对于一维数组,转置是恒等变换:

>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)

对于一个二维数组,转置是一个矩阵转置:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)

对于一个N维数组,转置会反转轴的顺序:

>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)

可以通过指定 axes 参数来改变这种默认行为:

>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)

由于交换最后两个轴是一个常见的操作,可以通过其自身的API来完成,即 jax.numpy.matrix_transpose():

>>> jnp.matrix_transpose(x).shape
(3, 5, 4)

为了方便,转置也可以使用 jax.Array.transpose() 方法或 jax.Array.T 属性来执行:

>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)