jax.lax.all_to_all#
- jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)[源代码][源代码]#
实现映射的轴并映射不同的轴。
如果
x是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。在输出中,输入映射的轴
axis_name在逻辑轴位置concat_axis处被具体化,而输入未映射的轴在位置split_axis处被映射为名称axis_name。映射轴的组大小必须等于未映射轴的大小;也就是说,我们必须有
lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]。默认情况下,当axis_index_groups=None时,这包括所有设备。- 参数:
x – 带有映射轴名为
axis_name的数组。axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅
jax.pmap()文档)。split_axis – 表示未映射的
x轴,使用名称axis_name进行映射的整数。concat_axis – int 表示在输出中具体化输入的映射轴的位置,该轴的名称为
axis_name。axis_index_groups – 包含轴索引的可选列表(例如,对于大小为4的轴,[[0, 1], [2, 3]] 将对前两个和后两个副本进行 all_to_all 操作)。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。
tiled – 当为 True 时,all_to_all 会将 split_axis 分割成块,并在 concat_axis 上将它们连接起来。特别地,不会增加或减少维度。默认值为 False。
- 返回:
当 tiled 为 False 时,形状由以下表达式给出的数组(或数组):: np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) 其中
axis_size是输入x中名为axis_name的映射轴的大小,即axis_size = lax.psum(1, axis_name)。否则,形状类似于输入形状的数组,除了 split_axis 被轴大小除,concat_axis 被轴大小乘。