jax.experimental.sparse.bcoo_dot_general_sampled

jax.experimental.sparse.bcoo_dot_general_sampled#

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)[源代码][源代码]#

在给定的稀疏索引处计算输出的收缩操作。

参数:
  • lhs – 一个 ndarray。

  • rhs – 一个 ndarray。

  • indices (Array) – BCOO 索引。

  • dimension_numbers (DotDimensionNumbers) – 一个由元组组成的元组,形式为 ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • A (Array)

  • B (Array)

返回:

BCOO 数据,包含结果的 ndarray。

返回类型:

Array