jax.lax.psum

目录

jax.lax.psum#

jax.lax.psum(x, axis_name, *, axis_index_groups=None)[源代码][源代码]#

x 上计算 axis_name 轴上的 pmapped 轴的 all-reduce 和。

如果 x 是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。

布尔类型的输入在归约之前会被转换为整数。

参数:
  • x – 带有映射轴名为 axis_name 的数组。

  • axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅 jax.pmap() 文档)。

  • axis_index_groups – 包含轴索引的可选列表列表(例如,对于大小为4的轴,[[0, 1], [2, 3]] 将对前两个和最后两个副本执行psum操作)。组必须完全覆盖所有轴索引一次。

返回:

x 形状相同的数组,表示沿轴 axis_name 进行 all-reduce 求和的结果。

示例

例如,如果有4个可用的XLA设备:

>>> x = np.arange(4)
>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[6 6 6 6]
>>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[0.         0.16666667 0.33333334 0.5       ]

假设我们想在两个组之间执行 psum,一组包括 device0device1,另一组包括 device2device3

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[1 1 5 5]

使用二维形状 x 的示例。每一行是一个设备的数据。

>>> x = np.arange(16).reshape(4, 4)
>>> print(x)
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

所有设备上的完整 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
>>> print(y)
[[24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]
 [24 28 32 36]]

在两个组之间执行 psum

>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
>>> print(y)
[[ 4  6  8 10]
 [ 4  6  8 10]
 [20 22 24 26]
 [20 22 24 26]]