Writing a training loop from scratch in TensorFlow

Author: fchollet
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training & evaluation loops in TensorFlow.

Setup

import time
import os

# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import numpy as np

Introduction

Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.

If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().

Now, if you want very low-level control over training and evaluation, you should write your own training and evaluation loops from scratch. This is what this guide is about.


A first end-to-end example

Let's consider a simple MNIST model:

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


model = get_model()

Let's train it using mini-batch gradient descent with a custom training loop.

First, we're going to need an optimizer, a loss function, and a dataset:

# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

Calling a model inside a GradientTape scope enables you to retrieve the gradients of the trainable weights of the layers with respect to a loss value. Using an optimizer instance, you can use these gradients to update these variables (which you can retrieve using model.trainable_weights).
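As a minimal sketch of this mechanism (a toy scalar example rather than a full model, just to show how the tape records a computation and returns its gradient):

```python
import tensorflow as tf

# A single trainable variable, x = 3.0.
x = tf.Variable(3.0)

# Record the forward computation on the tape.
with tf.GradientTape() as tape:
    y = x * x  # y = x^2

# dy/dx = 2x, so the gradient at x = 3.0 is 6.0.
grad = tape.gradient(y, x)
print(float(grad))  # 6.0
```

The same pattern scales directly to a full model: the forward pass happens inside the tape scope, and `tape.gradient(loss, model.trainable_weights)` returns one gradient per weight.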

Here's our training loop, step by step:

  • We open a for loop that iterates over epochs
  • For each epoch, we open a for loop that iterates over the dataset, in batches
  • For each batch, we open a GradientTape() scope
  • Inside this scope, we call the model (forward pass) and compute the loss
  • Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
  • Finally, we use the optimizer to update the weights of the model based on the gradients

epochs = 3
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply(grads, model.trainable_weights)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
Start of epoch 0
Training loss (for 1 batch) at step 0: 95.3300
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.5622
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.1138
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6748
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.3308
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.9813
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8640
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.0696
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3662
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.9556
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7459
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.0468
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7392
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8435
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.3859
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.4156
Seen so far: 48032 samples
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.4045
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.5983
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.3154
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.7911
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.2607
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2303
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.6048
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.7041
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3669
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6389
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7739
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3888
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8133
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2034
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.0768
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1544
Seen so far: 48032 samples
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.1250
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.0152
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.0917
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1330
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.0884
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2656
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.4375
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2246
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.0748
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1765
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.0130
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4030
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.0667
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.0553
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6513
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.0599
Seen so far: 48032 samples

Low-level handling of metrics

Let's add metrics monitoring to this basic loop.

You can readily reuse the built-in metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

  • Instantiate the metric at the start of the loop
  • Call metric.update_state() after each batch
  • Call metric.result() when you need to display the current value of the metric
  • Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch)
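
This lifecycle can be sketched in isolation (a toy example with hand-written logits, independent of the model above):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras

# Instantiate the metric once, outside the loop.
metric = keras.metrics.SparseCategoricalAccuracy()

# One "batch" of logits and integer labels.
logits = np.array([[0.1, 2.0], [3.0, 0.2]])  # argmax: class 1, class 0
labels = np.array([1, 1])                    # second prediction is wrong

# Accumulate state after each batch.
metric.update_state(labels, logits)
acc = float(metric.result())
print(acc)  # 0.5 -> one of two predictions correct

# Clear the state, e.g. at the end of an epoch.
metric.reset_state()
print(float(metric.result()))  # 0.0 after reset
```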

Let's use this knowledge to compute SparseCategoricalAccuracy on training and validation data at the end of each epoch:

# Get a fresh model
model = get_model()

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

Here's our training & evaluation loop:

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(grads, model.trainable_weights)

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 89.1303
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 1.0351
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 2.9143
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.7842
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.9583
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.1100
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 2.1144
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6801
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.6202
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.2570
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3638
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.8402
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7836
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.5147
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4798
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1653
Seen so far: 48032 samples
Training acc over epoch: 0.7961
Validation acc: 0.8825
Time taken: 46.06s
Start of epoch 1
Training loss (for 1 batch) at step 0: 1.3917
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2600
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.7206
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.4987
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.3410
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.6788
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.1355
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1762
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.1801
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.3515
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.4344
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.2027
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4649
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6848
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4594
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3548
Seen so far: 48032 samples
Training acc over epoch: 0.8896
Validation acc: 0.9094
Time taken: 43.49s
Speeding up your training step with tf.function

The default runtime in TensorFlow is eager execution. As such, our training loop above executes eagerly.

This is great for debugging, but graph compilation has a definite performance advantage. Describing your computation as a static graph enables the framework to apply global performance optimizations. This is impossible when the framework is constrained to greedily execute one operation after another, with no knowledge of what comes next.

You can compile into a static graph any function that takes tensors as input. Just add a @tf.function decorator (https://www.tensorflow.org/api_docs/python/tf/function) on it, like this:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value
Let's do the same with the evaluation step:

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)
Now, let's re-run our training loop with this compiled training step:

epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 0.5366
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2732
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.2478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0263
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4845
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2239
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.2242
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2122
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2856
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1957
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.2946
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3080
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2326
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6514
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.2018
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2812
Seen so far: 48032 samples
Training acc over epoch: 0.9104
Validation acc: 0.9199
Time taken: 5.73s
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.3080
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.3943
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1657
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1463
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.5359
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1894
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.1801
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1724
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3997
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6017
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1539
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1078
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8731
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3110
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6092
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2046
Seen so far: 48032 samples
Training acc over epoch: 0.9189
Validation acc: 0.9358
Time taken: 3.17s
Much faster, isn't it?

Low-level handling of losses tracked by the model

Layers & models recursively track any losses created during the forward pass by layers that call self.add_loss(value). The resulting list of scalar loss values is available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should sum them and add them to the main loss in your training step.

Consider this layer, that creates an activity regularization loss:

class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs
Let's build a really simple model that uses it:

inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
Here's what our training step should look like now:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value
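
To see the tracking in action, you can run a forward pass and inspect model.losses directly. This is a self-contained sketch (a tiny 4-feature model built just for the check, not the MNIST model above):

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras


class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        # Track a penalty proportional to the activations.
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs


inputs = keras.Input(shape=(4,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# The forward pass populates model.losses.
_ = model(tf.ones((2, 4)))
reg = float(sum(model.losses))
print(reg)  # 1e-2 * sum(ones of shape (2, 4)) = 1e-2 * 8 = 0.08
```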
Summary

Now you know everything there is to know about using built-in training loops and writing your own from scratch.

To conclude, here's a simple end-to-end example that ties together everything you've learned in this guide: a DCGAN trained on MNIST digits.

End-to-end example: a GAN training loop from scratch

You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new images that look almost real, by learning the latent distribution of a training dataset of images (the "latent space" of the images).

A GAN is made of two parts: a "generator" model that maps points in the latent space to points in image space, and a "discriminator" model, a classifier that can tell the difference between real images (from the training dataset) and fake images (the output of the generator network).

A GAN training loop looks like this:

1) Train the discriminator.

  • Sample a batch of random points in the latent space.
  • Turn the points into fake images via the "generator" model.
  • Get a batch of real images and combine them with the generated images.
  • Train the "discriminator" model to classify generated vs. real images.

2) Train the generator.

  • Sample random points in the latent space.
  • Turn the points into fake images via the "generator" network.
  • Get a batch of real images and combine them with the generated images.
  • Train the "generator" model to "fool" the discriminator and classify the fake images as real.

For a much more detailed overview of how GANs work, see Deep Learning with Python (https://www.manning.com/books/deep-learning-with-python).

Let's implement this training loop. First, create the discriminator meant to classify fake vs. real digits:

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.GlobalMaxPooling2D(),
        keras.layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 14, 14, 64)        │        640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 14, 14, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 7, 7, 128)         │     73,856 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 7, 7, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_max_pooling2d            │ (None, 128)               │          0 │
│ (GlobalMaxPooling2D)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_6 (Dense)                 │ (None, 1)                 │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 74,625 (291.50 KB)
 Trainable params: 74,625 (291.50 KB)
 Non-trainable params: 0 (0.00 B)
Then let's create a generator network, that turns latent vectors into outputs of shape `(28, 28, 1)` (representing MNIST digits):
latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        keras.layers.Dense(7 * 7 * 128),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Reshape((7, 7, 128)),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        keras.layers.LeakyReLU(negative_slope=0.2),
        keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)
Here's the key bit: the training loop. As you can see, it is quite straightforward. The training step function only takes 17 lines.

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply(grads, discriminator.trainable_weights)

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply(grads, generator.trainable_weights)
    return d_loss, g_loss, generated_images
Let's train our GAN by repeatedly calling train_step on batches of images. Since our discriminator and generator are convnets, you're going to want to run this code on a GPU.

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice, you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print(f"\nStart epoch {epoch}")

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 100 == 0:
            # Print metrics
            print(f"discriminator loss at step {step}: {d_loss:.2f}")
            print(f"adversarial loss at step {step}: {g_loss:.2f}")

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, f"generated_img_{step}.png"))

        # To limit execution time, we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break
Start epoch 0
discriminator loss at step 0: 0.69
adversarial loss at step 0: 0.69
That's it! You'll get nice-looking fake MNIST digits after just about 30 seconds of training on a Colab GPU.