jax.experimental.multihost_utils.broadcast_one_to_all#
- jax.experimental.multihost_utils.broadcast_one_to_all(in_tree, is_source=None)[源代码][源代码]#
将数据从源主机(默认主机0)广播到所有其他主机。
- 参数:
in_tree (Any) – 数组的pytree - 每个数组在主机之间 必须 具有相同的形状。
is_source (bool | None) – 可选的布尔值,表示调用者是否为源。只有 ‘源主机’ 会为广播贡献数据。如果为 None,则使用主机 0。
- 返回:
一个匹配 in_tree 的 pytree,其中所有叶子现在都包含来自第一个主机的数据。
- 返回类型:
Any