Keras 3 API 文档 / 多设备分布 / 模型并行API

模型并行API

[source]

ModelParallel class

keras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)

分布式模型变量分片.

与在所有设备上复制变量的 DataParallel 相比,ModelParallel 允许您在输入数据之外对变量进行分片.

要构造一个 ModelParallel 分布,您需要提供一个 DeviceMesh 和一个 LayoutMap.

  1. DeviceMesh 包含物理设备信息.网格中的轴名称将用于映射变量和数据布局.
  2. LayoutMap 包含变量路径与其对应的 TensorLayout 之间的映射.

示例:

devices = list_devices()    # 假设有 8 个设备.

# 创建一个网格,其中 2 个设备用于数据并行,4 个设备用于模型并行.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)
# 创建一个布局映射,将 `Dense` 层和 `Conv2D` 层变量在最后一个维度上分片.
# 基于 `device_mesh`,这意味着变量将跨 4 个设备拆分.布局映射中未匹配任何键的任何其他变量将完全复制.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(
    layout_map=layout_map,
    batch_dim_name='batch',
)

# 设置全局分布,或通过 `with distribution.scope():`
set_distribution(distribution)

model = model_creation()
model.compile()
model.fit(data)

您可以快速更新设备网格形状以更改变量的分片因子.例如:

# 仅通过设备网格的形状变化,变量将跨 8 个设备分片,而不是 4 个,这进一步减少了每个设备上变量的内存占用.
device_mesh = DeviceMesh(
    shape=(1, 8),
    axis_names=('batch', 'model'),
    devices=devices,
)

要为所有模型变量找出合适的布局映射规则,您可以首先列出所有模型变量路径,这些路径将用作将变量映射到 TensorLayout 的键.

例如:

model = create_model()
for v in model.variables:
    print(v.path)

参数: layout_map: LayoutMap 实例,将变量路径映射到相应的张量布局. batch_dim_name: 可选字符串,设备网格中的轴名称(layout_map 对象的),将用于分发数据.如果未指定,将使用设备网格的第一个轴.