jax.lax.psum_散布#
- jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[源代码][源代码]#
类似于
psum(x, axis_name)
,但每个设备只保留结果的一部分。例如,
psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)
计算的结果与psum(x, axis_name)[axis_index(axis_name)]
相同,但效率更高。因此,psum
的结果沿着映射的轴分散。计算
psum(x, axis_name)
的一种高效算法是执行psum_scatter
紧接着执行all_gather
,本质上是在计算all_gather(psum_scatter(x, axis_name))
。因此,我们可以将psum_scatter
视为psum
的“前半部分”。- 参数:
x – 带有映射轴名为
axis_name
的数组。axis_name – 用于命名映射轴的可哈希 Python 对象(更多详情请参阅
jax.pmap()
文档)。scatter_dimension – 一个位置轴,沿着
axis_name
的 all-reduce 结果将被分散到该轴上。axis_index_groups – 包含轴索引的可选整数列表列表。例如,对于大小为4的轴,
axis_index_groups=[[0, 1], [2, 3]]
将对前两个和后两个轴索引运行reduce-scatter。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。tiled – 表示是否使用保持秩的 ‘tiled’ 行为。当
False``(默认值)时,``scatter_dimension
中的维度大小必须与axis_name
轴的大小匹配(或者如果给出了axis_index_groups
,则为组大小)。在沿着scatter_dimension
分散所有减少的结果后,输出通过移除scatter_dimension
被压缩,因此结果的秩低于输入。当True
时,scatter_dimension
中的维度大小必须能被axis_name
轴的大小(或者如果给出了axis_index_groups
,则为组大小)整除,并且scatter_dimension
轴被保留(因此结果与输入具有相同的秩)。
- 返回:
与
x
形状相似的数组,除了在scatter_dimension
位置的维度大小被axis_name
轴的大小除(当tiled=True
时),或者在scatter_dimension
位置的维度被消除(当tiled=False
时)。
例如,如果有4个可用的XLA设备:
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x) >>> print(y) [24 28 32 36]
如果使用 tiled:
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x) >>> print(y) [[24] [28] [32] [36]]
使用 axis_index_groups 的示例:
>>> def f(x): ... return jax.lax.psum_scatter( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[ 8 10] [20 22] [12 14] [16 18]]