jax.numpy.linalg.tensorinv#
- jax.numpy.linalg.tensorinv(a, ind=2)[源代码][源代码]#
计算数组的张量逆。
JAX 实现的
numpy.linalg.tensorinv()
。这计算了与相同
ind
值的tensordot()
操作的逆。- 参数:
a (ArrayLike) – 要被反转的数组。必须满足
prod(a.shape[:ind]) == prod(a.shape[ind:])
ind (int) – 正整数,指定张量积中的索引数。
- 返回:
形状为
(*a.shape[ind:], *a.shape[:ind])
的数组,包含张量a
的逆。- 返回类型:
示例
>>> key = jax.random.key(1337) >>> x = jax.random.normal(key, shape=(2, 2, 4)) >>> xinv = jnp.linalg.tensorinv(x, 2) >>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2) >>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4) Array(True, dtype=bool)