jax.lax.psum#
- jax.lax.psum(x, axis_name, *, axis_index_groups=None)[源代码][源代码]#
在
x
上计算axis_name
轴上的 pmapped 轴的 all-reduce 和。如果
x
是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。布尔类型的输入在归约之前会被转换为整数。
- 参数:
x – 带有映射轴名为
axis_name
的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅
jax.pmap()
文档)。axis_index_groups – 包含轴索引的可选列表列表(例如,对于大小为4的轴,[[0, 1], [2, 3]] 将对前两个和最后两个副本执行psum操作)。组必须完全覆盖所有轴索引一次。
- 返回:
与
x
形状相同的数组,表示沿轴axis_name
进行 all-reduce 求和的结果。
示例
例如,如果有4个可用的XLA设备:
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [6 6 6 6] >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [0. 0.16666667 0.33333334 0.5 ]
假设我们想在两个组之间执行
psum
,一组包括device0
和device1
,另一组包括device2
和device3
,>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x) >>> print(y) [1 1 5 5]
使用二维形状 x 的示例。每一行是一个设备的数据。
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]]
所有设备上的完整
psum
:>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x) >>> print(y) [[24 28 32 36] [24 28 32 36] [24 28 32 36] [24 28 32 36]]
在两个组之间执行
psum
:>>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x) >>> print(y) [[ 4 6 8 10] [ 4 6 8 10] [20 22 24 26] [20 22 24 26]]