jax.experimental.mesh_utils.create_hybrid_device_mesh#
- jax.experimental.mesh_utils.create_hybrid_device_mesh(mesh_shape, dcn_mesh_shape, devices=None, *, process_is_granule=False, should_sort_granules_by_key=True, allow_split_physical_axes=False)[源代码][源代码]#
为混合(例如,ICI 和 DCN)并行创建一个设备网格。
- 参数:
mesh_shape (Sequence[int]) – 更快/内部网络的逻辑网格形状,按网络强度递增顺序排列,例如 [replica, data, mdl],其中 mdl 具有最多的网络通信需求。
dcn_mesh_shape (Sequence[int]) – 较慢/外部网络的逻辑网格形状,顺序与 mesh_shape 相同。
devices (Sequence[Any] | None) – 可选地,用于构建网格的设备。默认为 jax.devices()。
process_is_granule (bool) – 如果为真,此函数将把进程视为较慢/外部网络的单位。否则,它将查找设备上的 slice_index 属性,并将切片作为单位。启用此功能意味着作为未设置 slice_index 的平台的一种回退。
should_sort_granules_by_key (bool) – 设备颗粒是否应根据颗粒键进行排序,无论是切片还是进程索引,取决于 process_is_granule。
allow_split_physical_axes (bool) – 如果为真,我们将根据需要拆分物理轴以生成所需的设备网格。
- 抛出:
ValueError – 如果 devices 所属的切片数量不等于 dcn_mesh_shape 的乘积,或者任何单个切片中的设备数量不等于 mesh_shape 的乘积。
- 返回:
一个形状为 mesh_shape * dcn_mesh_shape 的 JAX 设备 np.ndarray,可以输入到 jax.sharding.Mesh 中用于混合并行。
- 返回类型:
np.ndarray