jax.experimental.mesh_utils 模块

目录

jax.experimental.mesh_utils 模块#

用于构建设备网格的工具。

API#

create_device_mesh(mesh_shape[, devices, ...])

为 jax.sharding.Mesh 创建一个高性能的设备网格。

create_hybrid_device_mesh(mesh_shape, ...[, ...])

为混合(例如,ICI 和 DCN)并行创建一个设备网格。