jax.numpy.einsum_path

目录

jax.numpy.einsum_path#

jax.numpy.einsum_path(subscripts: str, /, *operands: ArrayLike, optimize: bool | str | list[tuple[int, ...]] = 'auto') tuple[list[tuple[int, ...]], Any][源代码][源代码]#
jax.numpy.einsum_path(arr: ArrayLike, axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], optimize: bool | str | list[tuple[int, ...]] = 'auto') tuple[list[tuple[int, ...]], Any]

在不评估einsum的情况下,评估最优的收缩路径。

JAX 实现的 numpy.einsum_path()。此函数调用了 opt_einsum 包,并利用其优化例程。

参数:
  • subscripts – 包含用逗号分隔的轴名称的字符串。

  • *operands – 对应于下标的多个数组的序列。

  • optimize – 指定如何优化计算顺序。在 JAX 中,这默认为 "auto"。其他选项包括 True``(与 ``"optimize" 相同),False``(未优化),或任何 ``opt_einsum 支持的字符串,其中包括 "optimize""greedy""eager" 等。

返回:

包含可以传递给 einsum() 的路径的元组,以及表示此最优路径的可打印对象。

示例

>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3))
>>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100))
>>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5))
>>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal")
>>> print(path)
[(1, 2), (0, 1)]
>>> print(path_info)
      Complete contraction:  ij,jk,kl->il
            Naive scaling:  4
        Optimized scaling:  3
          Naive FLOP count:  9.000e+3
      Optimized FLOP count:  3.060e+3
      Theoretical speedup:  2.941e+0
      Largest intermediate:  1.500e+1 elements
    --------------------------------------------------------------------------------
    scaling        BLAS                current                             remaining
    --------------------------------------------------------------------------------
      3           GEMM              kl,jk->lj                             ij,lj->il
      3           GEMM              lj,ij->il                                il->il

einsum() 中使用计算的路径:

>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path)
Array([[-539,  216,   95,  592,  209],
       [ 527,   76,  285, -436, -529]], dtype=int32)