开发者指南 / 使用 Keras 进行分布式训练 3

使用 Keras 进行分布式训练 3

作者: Qianli Zhu
创建日期: 2023/11/07
最后修改日期: 2023/11/07
描述: 多后端 Keras 的分布 API 完整指南。

在 Colab 中查看 GitHub 源代码


介绍

Keras 分布 API 是一个新的接口,旨在促进在各种后端如 JAX、TensorFlow 和 PyTorch 上的分布式深度学习。这个强大的 API 引入了一套工具,使数据和模型并行成为可能,从而在多个加速器和主机上有效扩展深度学习模型。无论是利用 GPU 还是 TPU 的强大能力,该 API 都提供了一种简化的方法来初始化分布式环境、定义设备网格以及协调张量在计算资源上的布局。通过像 DataParallelModelParallel 这样的类,它抽象了并行计算中涉及的复杂性,使开发者更容易加速他们的机器学习工作流程。


工作原理

Keras 分布 API 提供了一个全局编程模型,允许开发者在全局上下文中组合操作张量的应用程序(就像在单一设备上工作一样),同时自动管理多个设备之间的分布。该 API 利用底层框架(例如 JAX)根据分片指令通过称为单程序、多数据(SPMD)扩展的过程来分配程序和张量。

通过将应用程序与分片指令解耦,API 能够在单一设备、多个设备,甚至多个客户端上运行相同的应用程序,同时保持其全局语义。


设置

import os

# 目前分布 API 仅为 JAX 后端实现。
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # 用于数据集输入。

DeviceMeshTensorLayout

Keras 分布 API 中的 keras.distribution.DeviceMesh 类表示配置用于分布式计算的计算设备集群。它与 jax.sharding.Meshtf.dtensor.Mesh 中的类似概念相符, 在这些概念中,它用于将物理设备映射到逻辑网格结构。

TensorLayout 类指定了张量如何在 DeviceMesh 中分布,详细列出了沿着与 DeviceMesh 中轴的名称对应的指定轴的张量分片。

您可以在 TensorFlow DTensor 指南 中找到更详细的概念解释。

# 获取本地可用的 GPU 设备。
devices = jax.devices("gpu")  # 假设它有 8 个本地 GPU。

# 定义一个 2x4 的设备网格,具有数据和模型并行轴
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# 一个 2D 布局,描述张量如何在网格上分布。
# 该布局可以可视化为一个 2D 网格,其中 "model" 为行,
# "data" 为列,在映射到网格上的物理设备时它是 [4, 2] 的网格。
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# 一个 4D 布局,可用于图像输入的数据并行。
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)

分布

Keras 中的 Distribution 类作为一个基础抽象类,旨在开发自定义分布策略。它封装了将模型的变量、输入数据和中间计算在设备网格上分配所需的核心逻辑。作为最终用户,您无需直接与此类交互,但其子类如 DataParallelModelParallel


DataParallel

Keras 分布 API 中的 DataParallel 类旨在实现用于分布式训练的数据并行策略,在该策略中,模型权重在 DeviceMesh 中的所有设备上复制,每个设备处理一部分输入数据。

这是该类的一个示例用法。

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349

0.842325747013092

ModelParallelLayoutMap

ModelParallel 在模型权重过大无法适应单个加速器时将尤其有用。此设置允许您将模型权重或激活张量分配到 DeviceMesh 上的所有设备上,并为大型模型启用横向扩展。

与所有权重完全复制的 DataParallel 模型不同,ModelParallel 下的权重布局通常需要一些定制以获得最佳性能。我们引入 LayoutMap 让您从全局角度指定任何权重和中间张量的 TensorLayout

LayoutMap 是一个类似字典的对象,将字符串映射到 TensorLayout 实例。在检索值时,字符串键被视为正则表达式,这使其与普通 Python 字典的行为不同。该类允许您定义 TensorLayout 的命名方案,然后检索相应的 TensorLayout 实例。通常,查询使用的键是 variable.path 属性,这个属性是变量的标识符。作为快捷方式,插入值时也允许使用元组或列表的轴名称,它将被转换为 TensorLayout

LayoutMap 也可以选择性地包含一个 DeviceMesh,以填充 TensorLayout.device_mesh,如果未设置此项。在通过键检索布局时,如果没有找到完全匹配的项,则布局映射中的所有现有键将被视为正则表达式,并再次与输入键匹配。如果存在多个匹配项,则会引发 ValueError。如果没有找到匹配项,则返回 None

mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# 下面的规则表示对于任何与 d1/kernel 匹配的权重,
# 它将与模型维度 (4 个设备) 一起被分割,d1/bias 同样如此。
# 所有其他权重将被完全复制。
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# 您还可以设置层输出的布局,如下所示
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# 数据将在该方法的 "data" 维度上进行分割,该维度有 2 个设备。
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3

/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: 一些捐赠的缓冲区无法使用: ShapedArray(float32[784,50]).
有关解释,请访问 https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("一些捐赠的缓冲区无法使用:"

 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
 8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381  

0.8502610325813293

更改网格结构以调整不同的数据并行或模型并行之间的计算也很简单。您可以通过调整网格的形状来实现此目的。并且不需要对任何其他代码进行更改。

full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)

深入阅读

  1. JAX 分布式数组和自动并行化
  2. JAX 分片模块
  3. 使用 DTensors 进行 TensorFlow 分布式训练
  4. TensorFlow DTensor 概念
  5. 使用 DTensors 与 tf.keras