jax.lax.psum_散布

jax.lax.psum_散布#

jax.lax.psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, tiled=False)[源代码][源代码]#

类似于 psum(x, axis_name),但每个设备只保留结果的一部分。

例如,psum_scatter(x, axis_name, scatter_dimension=0, tiled=False) 计算的结果与 psum(x, axis_name)[axis_index(axis_name)] 相同,但效率更高。因此,psum 的结果沿着映射的轴分散。

计算 psum(x, axis_name) 的一种高效算法是执行 psum_scatter 紧接着执行 all_gather,本质上是在计算 all_gather(psum_scatter(x, axis_name))。因此,我们可以将 psum_scatter 视为 psum 的“前半部分”。

参数:
  • x – 带有映射轴名为 axis_name 的数组。

  • axis_name – 用于命名映射轴的可哈希 Python 对象(更多详情请参阅 jax.pmap() 文档)。

  • scatter_dimension – 一个位置轴,沿着 axis_name 的 all-reduce 结果将被分散到该轴上。

  • axis_index_groups – 包含轴索引的可选整数列表列表。例如,对于大小为4的轴,axis_index_groups=[[0, 1], [2, 3]] 将对前两个和后两个轴索引运行reduce-scatter。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。

  • tiled – 表示是否使用保持秩的 ‘tiled’ 行为。当 False``(默认值)时,``scatter_dimension 中的维度大小必须与 axis_name 轴的大小匹配(或者如果给出了 axis_index_groups,则为组大小)。在沿着 scatter_dimension 分散所有减少的结果后,输出通过移除 scatter_dimension 被压缩,因此结果的秩低于输入。当 True 时,scatter_dimension 中的维度大小必须能被 axis_name 轴的大小(或者如果给出了 axis_index_groups,则为组大小)整除,并且 scatter_dimension 轴被保留(因此结果与输入具有相同的秩)。

返回:

x 形状相似的数组,除了在 scatter_dimension 位置的维度大小被 axis_name 轴的大小除(当 tiled=True 时),或者在 scatter_dimension 位置的维度被消除(当 tiled=False 时)。

例如,如果有4个可用的XLA设备:

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
>>> print(y)
[24 28 32 36]

如果使用 tiled:

>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
>>> print(y)
[[24]
 [28]
 [32]
 [36]]

使用 axis_index_groups 的示例:

>>> def f(x):
...   return jax.lax.psum_scatter(
...       x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
>>> y = jax.pmap(f, axis_name='i')(x)
>>> print(y)
[[ 8 10]
 [20 22]
 [12 14]
 [16 18]]