开发者指南 / 使用 TensorFlow 自定义 `fit()` 中的行为

使用 TensorFlow 自定义 fit() 中的行为

作者: fchollet
创建日期: 2020/04/15
最后修改: 2023/06/27
描述: 使用 TensorFlow 重新定义 Model 类的训练步骤。

在 Colab 中查看 GitHub 源代码


介绍

当你进行监督学习时,可以使用 fit() 一切都能顺利进行。

当你需要控制每一个细节时,可以完全从头编写自己的训练循环。

但是如果你需要一个自定义的训练算法,但仍然想利用 fit() 的便捷功能,比如回调、内置分布支持或步骤融合,该怎么办呢?

Keras 的一个核心原则是 逐步揭示复杂性。你应该能够以渐进的方式进入更低级的工作流程。如果高级功能与用例不完全匹配,你不应该掉入深渊。你应该能够在保持适当的高级便利性的同时,获得对小细节的更大控制。

当你需要自定义 fit() 的行为时,你应该 重写 Model 类的训练步骤函数。这是 fit() 为每个数据批次调用的函数。然后你将能够像往常一样调用 fit() —— 它将运行你自己的学习算法。

请注意,这种模式并不会阻止你使用功能性 API 构建模型。无论你是在构建 Sequential 模型、功能性 API 模型还是子类模型,你都可以这样做。

让我们看看这是如何工作的。


设置

import os

# 此指南只能在 TF 后端运行。
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import layers
import numpy as np

第一个简单示例

让我们从一个简单的例子开始:

  • 我们创建一个子类 keras.Model 的新类。
  • 我们只是重写方法 train_step(self, data)
  • 我们返回一个将指标名称(包括损失)映射到其当前值的字典。

输入参数 data 是传递给 fit 作为训练数据的内容:

  • 如果你通过调用 fit(x, y, ...) 传递 NumPy 数组,则 data 将是元组 (x, y)
  • 如果你通过调用 fit(dataset, ...) 传递 tf.data.Dataset,那么 data 将是每个批次从 dataset 生成的内容。

train_step() 方法的主体中,我们实现了一个常规的训练更新,类似于你已经熟悉的内容。重要的是,我们通过 self.compute_loss() 计算损失,它封装了传递给 compile() 的损失函数。

同样,我们在 self.metrics 的指标上调用 metric.update_state(y, y_pred),以更新在 compile() 中传递的指标的状态,并且在最后查询 self.metrics 的结果以检索它们的当前值。

class CustomModel(keras.Model):
    def train_step(self, data):
        # 解包数据。它的结构取决于你的模型和
        # 你传递给 `fit()` 的内容。
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # 前向传播
            # 计算损失值
            # (损失函数在 `compile()` 中配置)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # 计算梯度
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # 更新权重
        self.optimizer.apply(gradients, trainable_vars)

        # 更新指标(包括跟踪损失的指标)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # 返回一个将指标名称映射到当前值的字典
        return {m.name: m.result() for m in self.metrics}

让我们试试这个:

# 构建并编译 CustomModel 的一个实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# 像往常一样使用 `fit`
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778   
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319

WARNING: 所有在 absl::InitializeLog() 被调用之前的日志消息都被写入 STDERR
I0000 00:00:1699222602.443035       1 device_compiler.h:187] 编译集群使用 XLA! 该行在进程生命周期内最多记录一次。

<keras.src.callbacks.history.History at 0x2a5599f00>

更低层级的实现

自然,你可以选择跳过在 compile() 中传递损失函数,转而在 train_step手动完成所有操作。度量标准也是如此。

以下是一个更低级的示例,仅使用 compile() 配置优化器:

  • 我们开始创建 Metric 实例来跟踪我们的损失和 MAE 分数(在 __init__() 中)。
  • 我们实现一个自定义的 train_step(),更新这些度量的状态(通过调用 update_state()),然后通过 result() 查询它们以返回当前平均值,显示在进度条上并传递给任何回调。
  • 请注意,我们需要在每个 epoch 之间调用 reset_states()!否则调用 result() 将返回自训练开始以来的平均值,而我们通常处理的是每个 epoch 的平均值。幸运的是,框架可以为我们完成此操作:只需在模型的 metrics 属性中列出任何要重置的度量。模型将在每次 fit() epoch 开始时或调用 evaluate() 开始时调用列出对象的 reset_states()
class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # 前向传播
            # 计算我们自己的损失
            loss = self.loss_fn(y, y_pred)

        # 计算梯度
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # 更新权重
        self.optimizer.apply(gradients, trainable_vars)

        # 计算我们自己的度量
        self.loss_tracker.update_state(loss)
        self.mae_metric.update_state(y, y_pred)
        return {
            "loss": self.loss_tracker.result(),
            "mae": self.mae_metric.result(),
        }

    @property
    def metrics(self):
        # 我们在这里列出 `Metric` 对象,以便在每个 epoch 开始时
        # 或在 `evaluate()` 开始时可以自动调用 `reset_states()`。
        return [self.loss_tracker, self.mae_metric]


# 构造 CustomModel 的一个实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# 我们在这里不传递损失或度量。
model.compile(optimizer="adam")

# 像往常一样使用 `fit` -- 你可以使用回调等。
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
Epoch 1/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 4.0292 - mae: 1.9270
Epoch 2/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 385us/step - loss: 2.2155 - mae: 1.3920
Epoch 3/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 336us/step - loss: 1.1863 - mae: 0.9700
Epoch 4/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 373us/step - loss: 0.6510 - mae: 0.6811
Epoch 5/5
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 330us/step - loss: 0.4059 - mae: 0.5094

<keras.src.callbacks.history.History at 0x2a7a02860>

支持 sample_weightclass_weight

你可能注意到我们的第一个基本示例没有提到样本权重。如果你想支持 fit() 参数 sample_weightclass_weight,你只需执行以下操作:

  • data 参数中解包 sample_weight
  • 将其传递给 compute_lossupdate_state(当然,如果不依赖于 compile() 来处理损失和度量,你也可以手动应用)
  • 就这些。
class CustomModel(keras.Model):
    def train_step(self, data):
        # 解包数据。它的结构取决于您的模型和
        # 您传递给 `fit()` 的内容。
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # 向前传递
            # 计算损失值。
            # 损失函数在 `compile()` 中配置。
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # 计算梯度
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # 更新权重
        self.optimizer.apply(gradients, trainable_vars)

        # 更新指标。
        # 指标在 `compile()` 中配置。
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # 返回一个字典,将指标名称映射到当前值。
        # 注意,它将包含损失(跟踪在 self.metrics 中)。
        return {m.name: m.result() for m in self.metrics}


# 构建并编译 CustomModel 的实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# 现在您可以使用 sample_weight 参数
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)
Epoch 1/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.4228 - loss: 0.1420
Epoch 2/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 449us/step - mae: 0.3751 - loss: 0.1058
Epoch 3/3
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 337us/step - mae: 0.3478 - loss: 0.0951

<keras.src.callbacks.history.History at 0x2a7491780>

提供您自己的评估步骤

如果您想对 model.evaluate() 的调用做同样的事情呢?那么您可以以完全相同的方式覆盖 test_step。下面是它的样子:

class CustomModel(keras.Model):
    def test_step(self, data):
        # 解包数据
        x, y = data
        # 计算预测值
        y_pred = self(x, training=False)
        # 更新跟踪损失的指标
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # 更新指标。
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # 返回一个字典,映射指标名称到当前值。
        # 注意,它将包括损失(在 self.metrics 中跟踪)。
        return {m.name: m.result() for m in self.metrics}


# 构造 CustomModel 的实例
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# 使用我们的自定义 test_step 进行评估
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 927us/step - mae: 0.8518 - loss: 0.9166

[0.912325382232666, 0.8567370176315308]

总结:一个完整的 GAN 示例

让我们走过一个完整的示例,利用您刚刚学到的一切。

让我们考虑:

  • 一个生成器网络,用于生成 28x28x1 的图像。
  • 一个判别器网络,用于将 28x28x1 的图像分类为两个类别(“假”和“真”)。
  • 每个网络一个优化器。
  • 一个用于训练判别器的损失函数。
# 创建判别器
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# 创建生成器
latent_dim = 128
generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # 我们想生成 128 个系数来重塑为 7x7x128 的图
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

这是一个功能完整的 GAN 类,重写 compile() 以使用其自己的签名,并在 train_step 中用 17 行实现整个 GAN 算法:

class GAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.g_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.seed_generator = keras.random.SeedGenerator(1337)

    @property
    def metrics(self):
        return [self.d_loss_tracker, self.g_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]
        # 在潜在空间中采样随机点
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # 解码为假图像
        generated_images = self.generator(random_latent_vectors)

        # 将它们与真实图像结合
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # 组装标签以区分真实图像和假图像
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # 为标签添加随机噪声 - 重要技巧!
        labels += 0.05 * keras.random.uniform(
            tf.shape(labels), seed=self.seed_generator
        )

        # 训练判别器
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)

        # 在潜在空间中采样随机点
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # 组装标签,表示“所有真实图像”
        misleading_labels = tf.zeros((batch_size, 1))

        # 训练生成器(注意我们*不*应该更新判别器的权重)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply(grads, self.generator.trainable_weights)

        # 更新指标并返回它们的值。
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
        }

让我们进行测试驱动:

# 准备数据集。我们使用训练和测试的 MNIST 数字。
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# 为了限制执行时间,我们仅训练 100 个批次。您可以在
# 整个数据集上进行训练。您需要大约 20 个周期才能获得良好的结果。
gan.fit(dataset.take(100), epochs=1)
 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 500ms/step - d_loss: 0.5645 - g_loss: 0.7434

<keras.src.callbacks.history.History at 0x14a4f1b10>

深度学习背后的理念很简单,那为什么它的实现会如此痛苦呢?