jax.numpy.linalg.multi_dot

目录

jax.numpy.linalg.multi_dot#

jax.numpy.linalg.multi_dot(arrays, *, precision=None)[源代码][源代码]#

高效计算数组序列之间的矩阵乘积。

JAX 实现的 numpy.linalg.multi_dot()

JAX 内部使用 opt_einsum 库来计算最有效的操作顺序。

参数:
  • arrays (Sequence[ArrayLike]) – 数组序列。所有数组必须是二维的,除了第一个和最后一个可以是二维的。

  • precision (PrecisionLike) – None``(默认),这意味着后端的默认精度,一个 :class:`~jax.lax.Precision` 枚举值(``Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST)。

返回:

一个数组,表示与 reduce(jnp.matmul, arrays) 等效的内容,但以最佳顺序进行计算。

返回类型:

Array

这个函数存在的原因是,计算一系列矩阵乘法操作的成本可能会因操作的评估顺序而有很大差异。对于单个矩阵乘法,计算矩阵乘积所需的浮点运算次数(flops)可以这样近似:

>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]

假设我们有三个矩阵,我们想按顺序相乘:

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))

由于矩阵乘法的结合性,我们可能以两种顺序来计算乘积 x @ y @ z,并且两者在浮点精度内产生等效的输出:

>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)

但这些的计算成本差异很大:

>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000

第二种方法在估计的浮点运算次数方面大约高效20倍!

multi_dot 是一个函数,它将自动选择此类问题的最快计算路径:

>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)

我们可以使用 JAX 的 提前降低 工具来估计每种方法的总浮点运算次数,并确认 multi_dot 选择了更高效的选择:

>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0