jax.numpy.einsum_path#
- jax.numpy.einsum_path(subscripts: str, /, *operands: ArrayLike, optimize: bool | str | list[tuple[int, ...]] = 'auto') tuple[list[tuple[int, ...]], Any] [源代码][源代码]#
- jax.numpy.einsum_path(arr: ArrayLike, axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], optimize: bool | str | list[tuple[int, ...]] = 'auto') tuple[list[tuple[int, ...]], Any]
在不评估einsum的情况下,评估最优的收缩路径。
JAX 实现的
numpy.einsum_path()
。此函数调用了 opt_einsum 包,并利用其优化例程。- 参数:
subscripts – 包含用逗号分隔的轴名称的字符串。
*operands – 对应于下标的多个数组的序列。
optimize – 指定如何优化计算顺序。在 JAX 中,这默认为
"auto"
。其他选项包括True``(与 ``"optimize"
相同),False``(未优化),或任何 ``opt_einsum
支持的字符串,其中包括"optimize"
,"greedy"
,"eager"
等。
- 返回:
包含可以传递给
einsum()
的路径的元组,以及表示此最优路径的可打印对象。
示例
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3) >>> x = jax.random.randint(key1, minval=-5, maxval=5, shape=(2, 3)) >>> y = jax.random.randint(key2, minval=-5, maxval=5, shape=(3, 100)) >>> z = jax.random.randint(key3, minval=-5, maxval=5, shape=(100, 5)) >>> path, path_info = jnp.einsum_path("ij,jk,kl", x, y, z, optimize="optimal") >>> print(path) [(1, 2), (0, 1)] >>> print(path_info) Complete contraction: ij,jk,kl->il Naive scaling: 4 Optimized scaling: 3 Naive FLOP count: 9.000e+3 Optimized FLOP count: 3.060e+3 Theoretical speedup: 2.941e+0 Largest intermediate: 1.500e+1 elements -------------------------------------------------------------------------------- scaling BLAS current remaining -------------------------------------------------------------------------------- 3 GEMM kl,jk->lj ij,lj->il 3 GEMM lj,ij->il il->il
在
einsum()
中使用计算的路径:>>> jnp.einsum("ij,jk,kl", x, y, z, optimize=path) Array([[-539, 216, 95, 592, 209], [ 527, 76, 285, -436, -529]], dtype=int32)