jax.lax.all_to_all

目录

jax.lax.all_to_all#

jax.lax.all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None, tiled=False)[源代码][源代码]#

实现映射的轴并映射不同的轴。

如果 x 是一个 pytree,那么结果等同于将此函数映射到树中的每个叶子。

在输出中,输入映射的轴 axis_name 在逻辑轴位置 concat_axis 处被具体化,而输入未映射的轴在位置 split_axis 处被映射为名称 axis_name

映射轴的组大小必须等于未映射轴的大小;也就是说,我们必须有 lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]。默认情况下,当 axis_index_groups=None 时,这包括所有设备。

参数:
  • x – 带有映射轴名为 axis_name 的数组。

  • axis_name – 用于命名 pmapped 轴的可哈希 Python 对象(更多详情请参阅 jax.pmap() 文档)。

  • split_axis – 表示未映射的 x 轴,使用名称 axis_name 进行映射的整数。

  • concat_axis – int 表示在输出中具体化输入的映射轴的位置,该轴的名称为 axis_name

  • axis_index_groups – 包含轴索引的可选列表(例如,对于大小为4的轴,[[0, 1], [2, 3]] 将对前两个和后两个副本进行 all_to_all 操作)。组必须完全覆盖所有轴索引一次,并且所有组的大小必须相同。

  • tiled – 当为 True 时,all_to_all 会将 split_axis 分割成块,并在 concat_axis 上将它们连接起来。特别地,不会增加或减少维度。默认值为 False。

返回:

当 tiled 为 False 时,形状由以下表达式给出的数组(或数组):: np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size) 其中 axis_size 是输入 x 中名为 axis_name 的映射轴的大小,即 axis_size = lax.psum(1, axis_name)。否则,形状类似于输入形状的数组,除了 split_axis 被轴大小除,concat_axis 被轴大小乘。