jax.numpy.linalg.tensorinv

目录

jax.numpy.linalg.tensorinv#

jax.numpy.linalg.tensorinv(a, ind=2)[源代码][源代码]#

计算数组的张量逆。

JAX 实现的 numpy.linalg.tensorinv()

这计算了与相同 ind 值的 tensordot() 操作的逆。

参数:
  • a (ArrayLike) – 要被反转的数组。必须满足 prod(a.shape[:ind]) == prod(a.shape[ind:])

  • ind (int) – 正整数,指定张量积中的索引数。

返回:

形状为 (*a.shape[ind:], *a.shape[:ind]) 的数组,包含张量 a 的逆。

返回类型:

Array

示例

>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)