作者: fchollet
创建日期: 2020/04/28
最后修改: 2023/06/29
描述: 使用TensorFlow为Keras模型进行多GPU训练的指南。
通常有两种方法可以在多个设备上分布计算:
数据并行,其中单个模型在多个设备或多个机器上复制。每个设备处理不同的数据批次,然后它们合并结果。这种设置有许多变体,不同之处在于模型副本如何合并结果,它们是每批次保持同步还是更松散地耦合,等等。
模型并行,其中单个模型的不同部分在不同设备上运行,共同处理单个数据批次。这对于具有自然并行架构的模型效果最好,例如具有多个分支的模型。
本指南重点介绍数据并行,特别是同步数据并行,其中模型的不同副本在处理每个批次后保持同步。同步使模型的收敛行为与单设备训练时相同。
具体来说,本指南将教你如何使用 tf.distribute
API 在多个 GPU 上训练 Keras 模型,只需对代码进行最小的更改,在单个机器(单主机,多设备训练)上安装多个 GPU(通常是 2 到 16 个)。这是研究人员和小规模行业工作流中最常见的设置。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
在这种设置中,你有一台机器,上面有多个GPU(通常是2到16个)。每个设备将运行你的模型的副本(称为副本)。为了简单起见,在接下来的内容中,我们将假设我们正在处理8个GPU,但这并不失一般性。
工作原理
在训练的每一步:
在实践中,同步更新模型副本权重的过程是在每个单独的权重变量级别处理的。这是通过一个镜像变量对象完成的。
如何使用
要使用Keras模型进行单主机、多设备同步训练,你可以使用tf.distribute.MirroredStrategy
API。以下是它的工作原理:
MirroredStrategy
,可选地配置你想要使用的特定设备(默认情况下,策略将使用所有可用的GPU)。fit()
的第一次调用也可能创建变量,所以最好也将你的fit()
调用放在作用域内。fit()
像往常一样训练模型。重要的是,我们建议你使用tf.data.Dataset
对象在多设备或分布式工作流中加载数据。
大致上,它看起来像这样:
# 创建一个MirroredStrategy。
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# 打开一个策略作用域。
with strategy.scope():
# 所有创建变量的操作都应该在策略作用域内。
# 通常这只是模型构造和`compile()`。
model = Model(...)
model.compile(...)
# 在所有可用设备上训练模型。
model.fit(train_dataset, validation_data=val_dataset, ...)
# 在所有可用设备上测试模型。
model.evaluate(test_dataset)
这里有一个简单的端到端可运行示例:
def get_compiled_model():
# 创建一个简单的2层全连接神经网络。
inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return model
def get_dataset():
batch_size = 32
num_val_samples = 10000
# 返回以[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)形式表示的MNIST数据集。
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 预处理数据(这些是Numpy数组)
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")
# 保留num_val_samples个样本用于验证
x_val = x_train[-num_val_samples:]
y_val = y_train[-num_val_samples:]
x_train = x_train[:-num_val_samples]
y_train = y_train[:-num_val_samples]
return (
tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
)
# 创建一个MirroredStrategy。
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# 打开一个策略范围。
with strategy.scope():
# 所有创建变量的操作都应该在策略范围内。
# 通常这只是模型构建和`compile()`。
model = get_compiled_model()
# 在所有可用设备上训练模型。
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)
# 在所有可用设备上测试模型。
model.evaluate(test_dataset)
INFO:tensorflow:使用 MirroredStrategy 和设备 ('/job:localhost/replica:0/task:0/device:CPU:0',)
设备数量: 1
第1/2轮
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574
第2/2轮
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673
在使用分布式训练时,你应该始终确保有一种策略可以从故障中恢复(容错性)。处理这个问题的最简单方法是传递 ModelCheckpoint
回调给 fit()
,以定期保存你的模型(例如每100个批次或每个周期)。然后你可以从保存的模型中重新开始训练。
这里有一个简单的例子:
# 准备一个目录来存储所有的检查点。
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
def make_or_restore_model():
# 要么恢复最新的模型,要么如果没有检查点可用则创建一个新的模型。
checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
if checkpoints:
latest_checkpoint = max(checkpoints, key=os.path.getctime)
print("从", latest_checkpoint, "恢复")
return keras.models.load_model(latest_checkpoint)
print("创建一个新的模型")
return get_compiled_model()
def run_training(epochs=1):
# 创建一个 MirroredStrategy。
strategy = tf.distribute.MirroredStrategy()
# 打开一个策略作用域并创建/恢复模型
with strategy.scope():
model = make_or_restore_model()
callbacks = [
# 这个回调每周期保存一个 SavedModel
# 我们在文件夹名称中包含当前周期。
keras.callbacks.ModelCheckpoint(
filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
save_freq="epoch",
)
]
model.fit(
train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset,
verbose=2,
)
# 第一次运行时创建模型
run_training(epochs=1)
# 再次调用相同函数将从上次停止的地方继续
run_training(epochs=1)
信息:tensorflow:使用MirroredStrategy与设备('/job:localhost/replica:0/task:0/device:CPU:0',)
创建一个新模型
1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571
信息:tensorflow:使用MirroredStrategy与设备('/job:localhost/replica:0/task:0/device:CPU:0',)
从./ckpt/ckpt-1.keras恢复
1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710
tf.data
性能提示在进行分布式训练时,数据加载的效率往往变得至关重要。以下是一些确保你的 tf.data
管道运行尽可能快的提示。
关于数据集批处理的注意事项
在创建数据集时,确保它是以全局批量大小进行批处理的。例如,如果你的8个GPU每个都能运行一批64个样本,你可以使用全局批量大小为512。
调用 dataset.cache()
如果你在一个数据集上调用 .cache()
,它的数据将在第一次遍历数据后被缓存。每次后续迭代都将使用缓存的数据。缓存可以在内存中(默认)或指定到本地文件。
这可以在以下情况下提高性能:
调用 dataset.prefetch(buffer_size)
在创建数据集后,你几乎总是应该调用 .prefetch(buffer_size)
。这意味着你的数据管道将与模型异步运行,在当前批次样本用于训练模型时,新样本将被预处理并存储在缓冲区中。当前批次结束时,下一批次将被预取到GPU内存中。
就是这样!