jax.ops.segment_sum#
- jax.ops.segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None, mode=None)[源代码][源代码]#
计算数组中各段的和。
类似于 TensorFlow 的 segment_sum
- 参数:
data (ArrayLike) – 一个包含待求和值的数组。
segment_ids (ArrayLike) – 一个带有整数数据类型的数组,指示要对其 data (沿其主要轴)进行求和的段。值可以重复且不需要排序。
num_segments (int | None) – 可选,一个非负整数值,表示段落的数量。默认设置为支持
segment_ids
中所有索引的最小段落数,计算为max(segment_ids) + 1
。由于 num_segments 决定了输出的尺寸,因此在使用segment_sum
的 JIT 编译函数中必须提供一个静态值。indices_are_sorted (bool) – 是否
segment_ids
已知是排序的。unique_indices (bool) – 是否已知 segment_ids 没有重复项。
bucket_size (int | None) – 用于将索引分组的桶的大小。
segment_sum
分别在每个桶上执行,以提高加法数值稳定性。默认None
表示不进行分桶。mode (lax.GatherScatterMode | None) – 一个
jax.lax.GatherScatterMode
值,描述如何处理越界索引。默认情况下,范围 [0, num_segments) 之外的值将被丢弃,并且不会对总和产生贡献。
- 返回:
一个形状为
(num_segments,) + data.shape[1:]
的数组,表示分段和。- 返回类型:
示例
简单的一维段求和:
>>> data = jnp.arange(5) >>> segment_ids = jnp.array([0, 0, 1, 1, 2]) >>> segment_sum(data, segment_ids) Array([1, 5, 4], dtype=int32)
使用 JIT 需要静态 num_segments:
>>> from jax import jit >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3) Array([1, 5, 4], dtype=int32)