ModelParallel
classkeras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)
分布式模型变量分片.
与在所有设备上复制变量的 DataParallel
相比,ModelParallel
允许您在输入数据之外对变量进行分片.
要构造一个 ModelParallel
分布,您需要提供一个 DeviceMesh
和一个 LayoutMap
.
DeviceMesh
包含物理设备信息.网格中的轴名称将用于映射变量和数据布局.LayoutMap
包含变量路径与其对应的 TensorLayout
之间的映射.示例:
devices = list_devices() # 假设有 8 个设备.
# 创建一个网格,其中 2 个设备用于数据并行,4 个设备用于模型并行.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
devices=devices)
# 创建一个布局映射,将 `Dense` 层和 `Conv2D` 层变量在最后一个维度上分片.
# 基于 `device_mesh`,这意味着变量将跨 4 个设备拆分.布局映射中未匹配任何键的任何其他变量将完全复制.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)
# 设置全局分布,或通过 `with distribution.scope():`
set_distribution(distribution)
model = model_creation()
model.compile()
model.fit(data)
您可以快速更新设备网格形状以更改变量的分片因子.例如:
# 仅通过设备网格的形状变化,变量将跨 8 个设备分片,而不是 4 个,这进一步减少了每个设备上变量的内存占用.
device_mesh = DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=devices,
)
要为所有模型变量找出合适的布局映射规则,您可以首先列出所有模型变量路径,这些路径将用作将变量映射到 TensorLayout
的键.
例如:
model = create_model()
for v in model.variables:
print(v.path)
参数:
layout_map: LayoutMap
实例,将变量路径映射到相应的张量布局.
batch_dim_name: 可选字符串,设备网格中的轴名称(layout_map
对象的),将用于分发数据.如果未指定,将使用设备网格的第一个轴.