jax.lax.all_gather#
- jax.lax.all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False)[源代码][源代码]#
收集所有副本中的 x 值。
如果
x
是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。这相当于,但比 all_to_all(broadcast(x)) 更快。
- 参数:
x – 带有映射轴名为
axis_name
的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅
jax.pmap()
文档)。axis_index_groups – 包含轴索引的可选列表(例如,对于大小为4的轴,[[0, 1], [2, 3]] 将对前两个和后两个副本运行所有收集操作)。组必须准确覆盖所有轴索引一次,并且所有组的大小必须相同。
axis – 一个位置轴,沿着
axis_name
的块将被连接到这个轴上。tiled – 当
False
时,块将被堆叠到输出中索引为axis
的新位置轴上。当True
时,axis
必须引用一个现有的位置维度,并且块将被连接到该维度中。
- 返回:
表示沿轴
axis_name
进行 all-gather 操作结果的数组。形状与x.shape
相同,但: - 当tiled
为False
时,在位置axis
处有一个新维度,其大小等于轴axis_name
的大小, - 当tiled
为True
时,位置axis
处的维度大小乘以轴axis_name
的大小。
例如,如果有4个可用的XLA设备:
>>> x = np.arange(4) >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x) >>> print(y) [[0 1 2 3] [0 1 2 3] [0 1 2 3] [0 1 2 3]]
使用 axis_index_groups 的示例,按偶数和奇数设备 ID 分组:
>>> x = np.arange(16).reshape(4, 4) >>> print(x) [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15]] >>> def f(x): ... return jax.lax.all_gather( ... x, 'i', axis_index_groups=[[0, 2], [3, 1]]) >>> y = jax.pmap(f, axis_name='i')(x) >>> print(y) [[[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]] [[ 0 1 2 3] [ 8 9 10 11]] [[12 13 14 15] [ 4 5 6 7]]]