jax.ops.segment_prod

jax.ops.segment_prod#

jax.ops.segment_prod(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[源代码][源代码]#

计算数组中各段内的乘积。

类似于 TensorFlow 的 segment_prod

参数:
  • data (ArrayLike) – 一个包含要减少的值的数组。

  • segment_ids (ArrayLike) – 一个带有整数数据类型的数组,指示 data 的段(沿其主要轴)将被减少。值可以重复,不需要排序。范围 [0, num_segments) 之外的值将被丢弃,并且不会对结果产生贡献。

  • num_segments (int | None) – 可选,一个非负整数值,表示段落的数量。默认设置为支持 segment_ids 中所有索引的最小段落数,计算为 max(segment_ids) + 1。由于 num_segments 决定了输出的尺寸,因此在使用 segment_prod 的 JIT 编译函数中,必须提供一个静态值。

  • indices_are_sorted (bool) – 是否 segment_ids 已知是排序的。

  • unique_indices (bool) – 是否已知 segment_ids 没有重复项。

  • bucket_size (int | None) – 用于将索引分组的桶的大小。segment_prod 分别在每个桶上执行,以提高加法运算的数值稳定性。默认 None 表示不进行分桶。

  • mode (lax.GatherScatterMode | None) – 一个 jax.lax.GatherScatterMode 值,描述如何处理越界索引。默认情况下,范围 [0, num_segments) 之外的值将被丢弃,并且不会对总和产生贡献。

返回:

一个形状为 (num_segments,) + data.shape[1:] 的数组,表示分段乘积。

返回类型:

Array

示例

简单的一维线段产品:

>>> data = jnp.arange(6)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
>>> segment_prod(data, segment_ids)
Array([ 0,  6, 20], dtype=int32)

使用 JIT 需要静态 num_segments

>>> from jax import jit
>>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
Array([ 0,  6, 20], dtype=int32)