jax.experimental.multihost_utils.process_allgather#
- jax.experimental.multihost_utils.process_allgather(in_tree, tiled=False)[源代码][源代码]#
从各个进程中收集数据。
- 参数:
in_tree (Any) – 数组的pytree - 每个数组在主机之间 _必须_ 具有相同的形状。
tiled (bool) – 是否堆叠或连接输出。默认为 False,即在索引 0 处堆叠到新的位置轴。
- 返回:
numpy 数组的 Pytrees。 * 如果输入是一个非完全可寻址的 jax.Array,那么数据将被完全复制。 * 如果输入是 numpy 数组或完全可寻址的 jax.Array,那么输出形状取决于 tiled 参数。 如果其为 False,那么输出将被堆叠,否则将被连接。 * 如果输入是一个标量,那么输出将被堆叠。
- 返回类型:
Any