jax.lax.dot_general#
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)[源代码][源代码]#
通用点积/收缩运算符。
封装了 XLA 的 DotGeneral 操作符。
dot_general
的语义很复杂,但大多数用户不需要直接使用它。相反,你可以使用更高层次的函数,如jax.numpy.dot()
、jax.numpy.matmul()
、jax.numpy.tensordot()
、jax.numpy.einsum()
等,这些函数会在底层构造适当的dot_general
调用。如果你真的想理解dot_general
本身,我们建议阅读 XLA 的 DotGeneral 操作文档。- 参数:
lhs (ArrayLike) – 一个数组
rhs (ArrayLike) – 一个数组
dimension_numbers (DotDimensionNumbers) – 一个由元组组成的元组,其中每个元组包含整数序列,形式为
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
precision (PrecisionLike) – 可选。可以是
None
,这意味着后端的默认精度,一个Precision
枚举值(Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
),或者是一个包含两个Precision
枚举的元组,指示lhs`
和rhs
的精度。preferred_element_type (DTypeLike | None) – 可选。可以是
None
,这意味着使用输入类型的默认累积类型,或者是一个数据类型,指示将结果累积到并返回该数据类型的结果。
- 返回:
一个数组,其第一个维度是(共享的)批次维度,接着是
lhs
非收缩/非批次维度,最后是rhs
非收缩/非批次维度。- 返回类型: