jax.numpy.linalg.矩阵转置#
- jax.numpy.linalg.matrix_transpose(x, /)[源代码][源代码]#
转置矩阵或矩阵堆栈。
JAX 实现的
numpy.linalg.matrix_transpose()
。- 参数:
x (ArrayLike) – 形状为
(..., M, N)
的数组- 返回:
形状为
(..., N, M)
的数组,包含x
的矩阵转置。- 返回类型:
参见
jax.numpy.transpose()
: 更通用的转置操作。示例
单个矩阵的转置:
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.linalg.matrix_transpose(x) Array([[1, 4], [2, 5], [3, 6]], dtype=int32)
矩阵堆栈的转置:
>>> x = jnp.array([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> jnp.linalg.matrix_transpose(x) Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)
为了方便,可以通过 JAX 数组对象的
mT
属性来完成相同的计算:>>> x.mT Array([[[1, 3], [2, 4]], [[5, 7], [6, 8]]], dtype=int32)