作者: nkovela1
创建日期: 2022/09/19
最后修改日期: 2022/09/26
描述: 指导如何在多个 Keras 模型之间共享自定义训练步骤。
这个例子展示了如何使用“Trainer 模式”创建一个自定义训练步骤,并可以在多个Keras模型中共享。该模式重写了 keras.Model
类的 train_step()
方法,允许进行超出简单监督学习的训练循环。
Trainer 模式还可以很容易地适应具有更复杂模型和更大自定义训练步骤,例如 这个端到端的 GAN 模型,通过在 Trainer 类定义中放置自定义训练步骤。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
# 加载 MNIST 数据集并标准化数据
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
可以通过重写 Model
子类的 train_step()
和 test_step()
方法来创建自定义训练和评估步骤:
class MyTrainer(keras.Model):
def __init__(self, model):
super().__init__()
self.model = model
# 在这里创建损失和指标。
self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()
@property
def metrics(self):
# 在这里列出指标。
return [self.accuracy_metric]
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
y_pred = self.model(x, training=True) # 向前传递
# 计算损失值
loss = self.loss_fn(y, y_pred)
# 计算梯度
trainable_vars = self.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# 更新权重
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# 更新指标
for metric in self.metrics:
metric.update_state(y, y_pred)
# 返回一个字典,将指标名称映射到当前值。
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
x, y = data
# 推断步骤
y_pred = self.model(x, training=False)
# 更新指标
for metric in self.metrics:
metric.update_state(y, y_pred)
return {m.name: m.result() for m in self.metrics}
def call(self, x):
# 相当于被包装的 keras.Model 的 `call()`
x = self.model(x)
return x
让我们定义两个可以共享我们的 Trainer 类及其自定义 train_step()
的不同模型:
# 使用顺序 API 定义的模型
model_a = keras.models.Sequential(
[
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation="softmax"),
]
)
# 使用函数式 API 定义的模型
func_input = keras.Input(shape=(28, 28, 1))
x = keras.layers.Flatten(input_shape=(28, 28))(func_input)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dropout(0.4)(x)
func_output = keras.layers.Dense(10, activation="softmax")(x)
model_b = keras.Model(func_input, func_output)
trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)
trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)
trainer_2.compile(optimizer=keras.optimizers.Adam())
trainer_2.fit(
x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)
第 1/5 纪元
...
第 4/5 纪元
938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/步 - sparse_categorical_accuracy: 0.9770 - val_sparse_categorical_accuracy: 0.9770
第 5/5 纪元
938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/步 - sparse_categorical_accuracy: 0.9805 - val_sparse_categorical_accuracy: 0.9789
<keras.src.callbacks.history.History at 0x7efe405fe560>