jax.lax.gather

目录

jax.lax.gather#

jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)[源代码][源代码]#

收集操作符。

封装了 XLA 的 Gather 操作符

gather 的语义很复杂,其 API 可能在将来发生变化。对于大多数用例,您应该优先使用 Numpy 风格的索引 <https://numpy.org/doc/stable/reference/arrays.indexing.html>`_(例如,`x[:, (1,4,7), …]),而不是直接使用 gather

参数:
  • operand (ArrayLike) – 应该从中取片的数组

  • start_indices (ArrayLike) – 应取切片的位置索引

  • dimension_numbers (GatherDimensionNumbers) – 一个 lax.GatherDimensionNumbers 对象,描述了 operandstart_indices 和输出维度之间的关系。

  • slice_sizes (Shape) – 每个切片的大小。必须是一个非负整数的序列,长度等于 ndim(operand)

  • indices_are_sorted (bool) – 是否已知 indices 已排序。如果为真,可能会提高某些后端上的性能。

  • unique_indices (bool) – 从 operand 收集的元素是否保证不相互重叠。如果为 True,这可能会提高某些后端的性能。JAX 不会检查此承诺:如果元素重叠,行为是未定义的。

  • mode (str | GatherScatterMode | None) – 如何处理越界索引:当设置为 'clip' 时,索引会被限制在边界内,使得切片在边界内;当设置为 'fill''drop' 时,gather 会返回一个充满 fill_value 的切片,用于受影响的切片。当设置为 'promise_in_bounds' 时,越界索引的行为是实现定义的。

  • fill_value – 当 mode'fill' 时,用于返回超出边界切片的填充值。否则忽略。默认值为:对于不精确类型为 NaN,对于有符号类型为最大负值,对于无符号类型为最大正值,对于布尔类型为 True

返回:

包含 gather 输出的数组。

返回类型:

Array