jax.numpy.linalg.tensorsolve#
- jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码][源代码]#
求解张量方程 a x = b 中的 x。
JAX 实现的
numpy.linalg.tensorsolve()
。- 参数:
- 返回:
数组 x 使得在重新排序
a
的轴之后,tensordot(a, x, x.ndim)
等价于b
。- 返回类型:
示例
>>> key1, key2 = jax.random.split(jax.random.key(8675309)) >>> a = jax.random.normal(key1, shape=(2, 2, 4)) >>> b = jax.random.normal(key2, shape=(2, 2)) >>> x = jnp.linalg.tensorsolve(a, b) >>> x.shape (4,)
现在展示
x
可以用来使用tensordot()
重建b
:>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim) >>> jnp.allclose(b, b_reconstructed) Array(True, dtype=bool)