jax.lax.dot_general

目录

jax.lax.dot_general#

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)[源代码][源代码]#

通用点积/收缩运算符。

封装了 XLA 的 DotGeneral 操作符。

dot_general 的语义很复杂,但大多数用户不需要直接使用它。相反,你可以使用更高层次的函数,如 jax.numpy.dot()jax.numpy.matmul()jax.numpy.tensordot()jax.numpy.einsum() 等,这些函数会在底层构造适当的 dot_general 调用。如果你真的想理解 dot_general 本身,我们建议阅读 XLA 的 DotGeneral 操作文档。

参数:
  • lhs (ArrayLike) – 一个数组

  • rhs (ArrayLike) – 一个数组

  • dimension_numbers (DotDimensionNumbers) – 一个由元组组成的元组,其中每个元组包含整数序列,形式为 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • precision (PrecisionLike) – 可选。可以是 None,这意味着后端的默认精度,一个 Precision 枚举值(Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST),或者是一个包含两个 Precision 枚举的元组,指示 lhs`rhs 的精度。

  • preferred_element_type (DTypeLike | None) – 可选。可以是 None,这意味着使用输入类型的默认累积类型,或者是一个数据类型,指示将结果累积到并返回该数据类型的结果。

返回:

一个数组,其第一个维度是(共享的)批次维度,接着是 lhs 非收缩/非批次维度,最后是 rhs 非收缩/非批次维度。

返回类型:

Array