jax.numpy.linalg.tensorsolve

目录

jax.numpy.linalg.tensorsolve#

jax.numpy.linalg.tensorsolve(a, b, axes=None)[源代码][源代码]#

求解张量方程 a x = b 中的 x。

JAX 实现的 numpy.linalg.tensorsolve()

参数:
  • a (ArrayLike) – 输入数组。通过 axes 重新排序后(见下文),形状必须为 (*b.shape, *x.shape)

  • b (ArrayLike) – 右侧数组。

  • axes (tuple[int, ...] | None) – 可选的元组,指定应移动到 a 末尾的轴

返回:

数组 x 使得在重新排序 a 的轴之后,tensordot(a, x, x.ndim) 等价于 b

返回类型:

Array

示例

>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

现在展示 x 可以用来使用 tensordot() 重建 b

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)