开发者指南 / 使用 JAX 进行多 GPU 分布式训练

使用 JAX 进行多 GPU 分布式训练

作者: fchollet
创建日期: 2023/07/11
最后修改日期: 2023/07/11
描述: 使用 JAX 对 Keras 模型进行多 GPU/TPU 训练的指南。

在 Colab 中查看 GitHub 源


介绍

一般来说,有两种方法可以将计算分布在多台设备上:

数据并行性,其中单个模型在多个设备或多个机器上进行复制。每个设备处理不同的数据批次,然后合并它们的结果。这个设置有许多变种,它们在如何合并不同模型副本的结果、每个批次的同步性、它们是否更松散耦合等方面有所不同。

模型并行性,其中单个模型的不同部分在不同设备上运行,共同处理单个数据批次。这种方法最适合具有自然并行架构的模型,例如那些具有多个分支的模型。

本指南主要关注数据并行性,特别是同步数据并行性,其中模型的不同副本在处理每个批次后保持同步。同步性使得模型的收敛行为与单设备训练的行为一致。

具体而言,本指南教你如何使用 jax.sharding API 以最小的代码改动在多个 GPU 或 TPU(通常为 2 到 16)上训练 Keras 模型。这是在单台机器上(单主机,多设备训练)最常见的设置,适合研究人员和小规模行业工作流。


设置

让我们先定义一个创建模型的函数,以及一个创建我们将要训练的数据集的函数(此处为 MNIST)。

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_model():
    # 创建一个简单的卷积神经网络,带有批量归一化和 dropout。
    inputs = keras.Input(shape=(28, 28, 1))
    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
        x
    )
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=24,
        kernel_size=6,
        use_bias=False,
        strides=2,
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.Conv2D(
        filters=32,
        kernel_size=6,
        padding="same",
        strides=2,
        name="large_k",
    )(x)
    x = keras.layers.BatchNormalization(scale=False, center=True)(x)
    x = keras.layers.ReLU()(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.5)(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # 加载数据并将其分为训练集和测试集
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # 将图像缩放到 [0, 1] 范围
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # 确保图像形状为 (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "训练样本")
    print(x_test.shape[0], "测试样本")

    # 创建 TF 数据集
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data

单主机,多设备同步训练

在这个设置中,您有一台机器,里面有多张 GPU 或 TPU(通常为 2 到 16)。每个设备将运行您模型的一个副本(称为副本)。为简单起见,接下来我们假设我们正在处理 8 个 GPU,而不失一般性。

它是如何工作的

在训练的每一个步骤中:

  • 当前的数据批次(称为全局批次)被分割成 8 个不同的子批次(称为局部批次)。例如,如果全局批次有 512 个样本,则每个 8 个局部批次将有 64 个样本。
  • 每个 8 个副本独立处理一个局部批次:它们进行一次前向传递, 然后进行向后传播,输出与模型在本地批次上的损失相关的权重梯度。
  • 来自局部梯度的权重更新在8个副本之间有效地合并。由于这是在每个步骤的结束时进行的,因此副本始终保持同步。

在实践中,同步更新模型副本的权重的过程是在每个单独的权重变量的层面上处理的。这是通过使用配置为复制变量的 jax.sharding.NamedSharding 来完成的。

如何使用它

要使用 Keras 模型进行单主机、多设备的同步训练,你需要使用 jax.sharding 特性。以下是其工作原理:

  • 我们首先使用 mesh_utils.create_device_mesh 创建一个设备网格。
  • 我们使用 jax.sharding.Meshjax.sharding.NamedShardingjax.sharding.PartitionSpec 来定义如何划分 JAX 数组。 - 我们通过使用没有轴的规范指定希望在所有设备上复制模型和优化器变量。 - 我们通过使用沿批次维度拆分的规范指定希望在设备上分片数据。
  • 我们使用 jax.device_put 在设备之间复制模型和优化器变量。这一步在开始时进行一次。
  • 在训练循环中,对于我们处理的每个批次,我们使用 jax.device_put 在调用训练步骤之前将批次拆分到设备上。

以下是流程,其中每个步骤都被分成自己的工具函数:

# 配置
num_epochs = 2
batch_size = 64

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 使用 .build() 初始化所有状态
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


# 这是将被微分的损失函数。
# Keras 提供了纯函数的前向传递:model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# 计算梯度的函数
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# 训练步骤,Keras 提供了纯函数的优化器.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# 在所有设备上复制模型和优化器变量
def get_replicated_train_state(devices):
    # 所有变量将在所有设备上复制
    var_mesh = Mesh(devices, axis_names=("_"))
    # 在 NamedSharding 中,没有提到的轴会被复制(这里所有轴)
    var_replication = NamedSharding(var_mesh, P())

    # 将分发设置应用于模型变量
    trainable_variables = jax.device_put(model.trainable_variables, var_replication)
    non_trainable_variables = jax.device_put(
        model.non_trainable_variables, var_replication
    )
    optimizer_variables = jax.device_put(optimizer.variables, var_replication)

    # 将所有状态组合成一个元组
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"在 {num_devices} 个设备上运行: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# 数据将在批次轴上进行拆分
data_mesh = Mesh(devices, axis_names=("batch",))  # 网格的轴命名
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # 分片分区的轴命名

# 显示数据分片
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("数据分片")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

# 自定义训练循环
for epoch in range(num_epochs):
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("第", epoch, "轮,损失:", loss_value)

# 后处理模型状态更新以将其写回模型
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)
x_train 形状: (60000, 28, 28, 1)
60000 训练样本
10000 测试样本
运行在 1 个设备上: [CpuDevice(id=0)]
数据分片
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                     CPU 0                                      
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
第 0 轮损失: 0.43531758
第 1 轮损失: 0.5194763

That's it!