WGAN-GP overriding `Model.train_step`
Author: A_K_Nain
Date created: 2020/05/09
Last modified: 2023/08/03
Description: Implementation of Wasserstein GAN with Gradient Penalty.
The original Wasserstein GAN leverages the Wasserstein distance to produce a value function that has better theoretical properties than the value function used in the original GAN paper. WGAN requires that the discriminator (a.k.a. the critic) lie within the space of 1-Lipschitz functions. The authors proposed the idea of weight clipping to achieve this constraint. Though weight clipping works, it can be a problematic way to enforce the 1-Lipschitz constraint and can cause undesirable behavior: for example, a very deep WGAN discriminator (critic) often fails to converge.

The WGAN-GP method proposes an alternative to weight clipping to ensure smooth training. Instead of clipping the weights, the authors proposed a "gradient penalty" by adding a loss term that keeps the L2 norm of the discriminator gradients close to 1.
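In equation form (this matches the `discriminator_loss`, `generator_loss`, and `gradient_penalty` functions implemented below), the critic and generator losses are

$$
L_D = \mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim P_r}\big[D(x)\big] + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \qquad L_G = -\mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big],
$$

where $x$ is a real image, $\tilde{x}$ a generated image, $\hat{x} = x + \epsilon(\tilde{x} - x)$ with $\epsilon \sim U[0, 1]$ a random interpolate between the two, and $\lambda$ the penalty weight (10 in this example).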
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import tensorflow as tf
from keras import layers
To demonstrate how to train WGAN-GP, we will be using the Fashion-MNIST dataset. Each sample in this dataset is a 28x28 grayscale image associated with a label from 10 classes (e.g. trouser, pullover, sneaker, etc.).
IMG_SHAPE = (28, 28, 1)
BATCH_SIZE = 512
# Size of the noise vector
noise_dim = 128
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(f"Number of examples: {len(train_images)}")
print(f"Shape of the images in the dataset: {train_images.shape[1:]}")

# Reshape each sample to (28, 28, 1) and normalize the pixel values in the [-1, 1] range
train_images = train_images.reshape(train_images.shape[0], *IMG_SHAPE).astype("float32")
train_images = (train_images - 127.5) / 127.5
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
 29515/29515 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
 26421880/26421880 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
 5148/5148 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
 4422102/4422102 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Number of examples: 60000
Shape of the images in the dataset: (28, 28)
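As a quick optional check (not part of the original example), the normalization above maps pixel value 0 to -1.0 and 255 to 1.0:

# Optional sanity check: after normalization, the pixel values
# should span exactly [-1.0, 1.0].
print(train_images.min(), train_images.max())  # expected: -1.0 1.0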
The samples in the dataset have a (28, 28, 1) shape. Because we will be using strided convolutions, this can result in shapes with odd dimensions. For example, (28, 28) -> Conv_s2 -> (14, 14) -> Conv_s2 -> (7, 7) -> Conv_s2 -> (3, 3).

While performing upsampling in the generator part of the network, we won't get the same input shape as the original images if we aren't careful. To avoid this, we will do something much simpler:

- In the discriminator: "zero pad" the input to change the shape of each sample to (32, 32, 1); and
- In the generator: crop the final output to match the shape with the input shape. The shape arithmetic is traced in the sketch after this list.
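A minimal sketch (illustrative only, not part of the model code) of this shape bookkeeping; with "same" padding, a stride-2 convolution exactly halves an even spatial size:

size = 28 + 4  # ZeroPadding2D((2, 2)) adds 2 pixels on each side -> 32
for _ in range(4):  # four stride-2 convolutions in the discriminator
    size //= 2
print(size)  # 2, i.e. a (2, 2, 512) feature map before Flatten

size = 4  # the generator starts from a (4, 4, 256) tensor
for _ in range(3):  # three 2x UpSampling2D blocks
    size *= 2
print(size)  # 32
print(size - 4)  # Cropping2D((2, 2)) removes 2 pixels per side -> 28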
def conv_block(
x,
filters,
activation,
kernel_size=(3, 3),
strides=(1, 1),
padding="same",
use_bias=True,
use_bn=False,
use_dropout=False,
drop_value=0.5,
):
x = layers.Conv2D(
filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
)(x)
if use_bn:
x = layers.BatchNormalization()(x)
x = activation(x)
if use_dropout:
x = layers.Dropout(drop_value)(x)
return x
def get_discriminator_model():
img_input = layers.Input(shape=IMG_SHAPE)
    # Zero pad the input to make the input images size to (32, 32, 1).
x = layers.ZeroPadding2D((2, 2))(img_input)
x = conv_block(
x,
64,
kernel_size=(5, 5),
strides=(2, 2),
use_bn=False,
use_bias=True,
activation=layers.LeakyReLU(0.2),
use_dropout=False,
drop_value=0.3,
)
x = conv_block(
x,
128,
kernel_size=(5, 5),
strides=(2, 2),
use_bn=False,
activation=layers.LeakyReLU(0.2),
use_bias=True,
use_dropout=True,
drop_value=0.3,
)
x = conv_block(
x,
256,
kernel_size=(5, 5),
strides=(2, 2),
use_bn=False,
activation=layers.LeakyReLU(0.2),
use_bias=True,
use_dropout=True,
drop_value=0.3,
)
x = conv_block(
x,
512,
kernel_size=(5, 5),
strides=(2, 2),
use_bn=False,
activation=layers.LeakyReLU(0.2),
use_bias=True,
use_dropout=False,
drop_value=0.3,
)
x = layers.Flatten()(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(1)(x)
d_model = keras.models.Model(img_input, x, name="discriminator")
return d_model
d_model = get_discriminator_model()
d_model.summary()
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (None, 28, 28, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ zero_padding2d (ZeroPadding2D)  │ (None, 32, 32, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d (Conv2D)                 │ (None, 16, 16, 64)        │      1,664 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 8, 8, 128)         │    204,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 4, 4, 256)         │    819,456 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_2 (LeakyReLU)       │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 2, 2, 512)         │  3,277,312 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_3 (LeakyReLU)       │ (None, 2, 2, 512)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 2048)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 2048)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 1)                 │      2,049 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 4,305,409 (16.42 MB)
Trainable params: 4,305,409 (16.42 MB)
Non-trainable params: 0 (0.00 B)
def upsample_block(
x,
filters,
activation,
kernel_size=(3, 3),
strides=(1, 1),
up_size=(2, 2),
padding="same",
use_bn=False,
use_bias=True,
use_dropout=False,
drop_value=0.3,
):
x = layers.UpSampling2D(up_size)(x)
x = layers.Conv2D(
filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias
)(x)
if use_bn:
x = layers.BatchNormalization()(x)
if activation:
x = activation(x)
if use_dropout:
x = layers.Dropout(drop_value)(x)
return x
def get_generator_model():
noise = layers.Input(shape=(noise_dim,))
x = layers.Dense(4 * 4 * 256, use_bias=False)(noise)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Reshape((4, 4, 256))(x)
x = upsample_block(
x,
128,
layers.LeakyReLU(0.2),
strides=(1, 1),
use_bias=False,
use_bn=True,
padding="same",
use_dropout=False,
)
x = upsample_block(
x,
64,
layers.LeakyReLU(0.2),
strides=(1, 1),
use_bias=False,
use_bn=True,
padding="same",
use_dropout=False,
)
x = upsample_block(
x, 1, layers.Activation("tanh"), strides=(1, 1), use_bias=False, use_bn=True
)
# At this point, we have an output which has the same shape as the input, (32, 32, 1).
# We will use a Cropping2D layer to make it (28, 28, 1).
x = layers.Cropping2D((2, 2))(x)
g_model = keras.models.Model(noise, x, name="generator")
return g_model
g_model = get_generator_model()
g_model.summary()
Model: "generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer_1 (InputLayer)      │ (None, 128)               │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 4096)              │    524,288 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization             │ (None, 4096)              │     16,384 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_4 (LeakyReLU)       │ (None, 4096)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ reshape (Reshape)               │ (None, 4, 4, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d (UpSampling2D)    │ (None, 8, 8, 256)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 8, 8, 128)         │    294,912 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_1           │ (None, 8, 8, 128)         │        512 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_5 (LeakyReLU)       │ (None, 8, 8, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d_1 (UpSampling2D)  │ (None, 16, 16, 128)       │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_5 (Conv2D)               │ (None, 16, 16, 64)        │     73,728 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_2           │ (None, 16, 16, 64)        │        256 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_6 (LeakyReLU)       │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ up_sampling2d_2 (UpSampling2D)  │ (None, 32, 32, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_6 (Conv2D)               │ (None, 32, 32, 1)         │        576 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ batch_normalization_3           │ (None, 32, 32, 1)         │          4 │
│ (BatchNormalization)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ activation (Activation)         │ (None, 32, 32, 1)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ cropping2d (Cropping2D)         │ (None, 28, 28, 1)         │          0 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 910,660 (3.47 MB)
Trainable params: 902,082 (3.44 MB)
Non-trainable params: 8,578 (33.51 KB)
Now that we have defined our generator and discriminator, it's time to implement the WGAN-GP model. We will also override the `train_step` for training.
class WGAN(keras.Model):
def __init__(
self,
discriminator,
generator,
latent_dim,
discriminator_extra_steps=3,
gp_weight=10.0,
):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.d_steps = discriminator_extra_steps
self.gp_weight = gp_weight
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
super().compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.d_loss_fn = d_loss_fn
self.g_loss_fn = g_loss_fn
    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        This loss is calculated on an interpolated image
        and added to the discriminator loss.
        """
        # Get the interpolated image
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp
    def train_step(self, real_images):
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        # For each batch, we are going to perform the
        # following steps as laid out in the original paper:
        # 1. Train the generator and get the generator loss
        # 2. Train the discriminator and get the discriminator loss
        # 3. Calculate the gradient penalty
        # 4. Multiply this gradient penalty with a constant weight factor
        # 5. Add the gradient penalty to the discriminator loss
        # 6. Return the generator and discriminator losses as a loss dictionary

        # Train the discriminator first. The original paper recommends training
        # the discriminator for `x` more steps (typically 5) as compared to
        # one step of the generator. Here we will train it for 3 extra steps
        # as compared to 5 to reduce the training time.
        for i in range(self.d_steps):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for the real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate the discriminator loss using the fake and real image logits
                d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
                # Calculate the gradient penalty
                gp = self.gradient_penalty(batch_size, real_images, fake_images)
                # Add the gradient penalty to the original discriminator loss
                d_loss = d_cost + gp * self.gp_weight

            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        # Train the generator
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}
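Before wiring everything together, one can sanity-check `gradient_penalty` on random tensors (a hypothetical quick check, not part of the original example). The penalty is a mean of squared deviations of a gradient norm from 1, so it should always be a non-negative scalar:

# Hypothetical quick check: the gradient penalty on random "images"
# should come out as a single non-negative scalar.
demo_wgan = WGAN(discriminator=d_model, generator=g_model, latent_dim=noise_dim)
demo_real = tf.random.normal((4, 28, 28, 1))
demo_fake = tf.random.normal((4, 28, 28, 1))
print(demo_wgan.gradient_penalty(4, demo_real, demo_fake))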
class GANMonitor(keras.callbacks.Callback):
def __init__(self, num_img=6, latent_dim=128):
self.num_img = num_img
self.latent_dim = latent_dim
def on_epoch_end(self, epoch, logs=None):
random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
generated_images = self.model.generator(random_latent_vectors)
generated_images = (generated_images * 127.5) + 127.5
for i in range(self.num_img):
img = generated_images[i].numpy()
img = keras.utils.array_to_img(img)
img.save("generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch))
# Instantiate the optimizer for both networks
# (learning_rate=0.0002, beta_1=0.5 are recommended)
generator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = keras.optimizers.Adam(
learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
# Define the loss functions for the discriminator,
# which should be (fake_loss - real_loss).
# We will add the gradient penalty later to this loss function.
def discriminator_loss(real_img, fake_img):
    real_loss = tf.reduce_mean(real_img)
    fake_loss = tf.reduce_mean(fake_img)
    return fake_loss - real_loss


# Define the loss functions for the generator.
def generator_loss(fake_img):
    return -tf.reduce_mean(fake_img)
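As a quick illustration of the sign conventions (hypothetical values, not from the original example): a critic that scores real images higher than fakes gets a negative, i.e. favorable, critic loss, while the generator loss rewards fakes that score high.

# Hypothetical check of the Wasserstein loss signs.
real_scores = tf.constant([2.0, 3.0])   # critic scores for real images
fake_scores = tf.constant([-1.0, 0.0])  # critic scores for fake images
print(discriminator_loss(real_scores, fake_scores).numpy())  # -3.0
print(generator_loss(fake_scores).numpy())  # 0.5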
# Set the number of epochs for training.
epochs = 20

# Instantiate the custom `GANMonitor` Keras callback.
cbk = GANMonitor(num_img=3, latent_dim=noise_dim)

# Get the wgan model
wgan = WGAN(
discriminator=d_model,
generator=g_model,
latent_dim=noise_dim,
discriminator_extra_steps=3,
)
# Compile the wgan model
wgan.compile(
d_optimizer=discriminator_optimizer,
g_optimizer=generator_optimizer,
g_loss_fn=generator_loss,
d_loss_fn=discriminator_loss,
)
# Start training
wgan.fit(train_images, batch_size=BATCH_SIZE, epochs=epochs, callbacks=[cbk])
Epoch 1/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 79s 345ms/step - d_loss: -7.7597 - g_loss: -17.2858 - loss: 0.0000e+00
Epoch 2/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 118ms/step - d_loss: -7.0841 - g_loss: -13.8542 - loss: 0.0000e+00
Epoch 3/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 118ms/step - d_loss: -6.1011 - g_loss: -13.2763 - loss: 0.0000e+00
Epoch 4/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - d_loss: -5.5292 - g_loss: -13.3122 - loss: 0.0000e+00
Epoch 5/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - d_loss: -5.1012 - g_loss: -12.1395 - loss: 0.0000e+00
Epoch 6/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - d_loss: -4.7557 - g_loss: -11.2559 - loss: 0.0000e+00
Epoch 7/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - d_loss: -4.4727 - g_loss: -10.3075 - loss: 0.0000e+00
Epoch 8/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 119ms/step - d_loss: -4.2056 - g_loss: -10.0340 - loss: 0.0000e+00
Epoch 9/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -4.0116 - g_loss: -9.9283 - loss: 0.0000e+00
Epoch 10/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.8050 - g_loss: -9.7392 - loss: 0.0000e+00
Epoch 11/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.6608 - g_loss: -9.4686 - loss: 0.0000e+00
Epoch 12/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 121ms/step - d_loss: -3.4623 - g_loss: -8.9601 - loss: 0.0000e+00
Epoch 13/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.3659 - g_loss: -8.4620 - loss: 0.0000e+00
Epoch 14/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.2486 - g_loss: -7.9598 - loss: 0.0000e+00
Epoch 15/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.1436 - g_loss: -7.5392 - loss: 0.0000e+00
Epoch 16/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -3.0370 - g_loss: -7.3694 - loss: 0.0000e+00
Epoch 17/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -2.9256 - g_loss: -7.6105 - loss: 0.0000e+00
Epoch 18/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -2.8976 - g_loss: -6.5240 - loss: 0.0000e+00
Epoch 19/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -2.7944 - g_loss: -6.6281 - loss: 0.0000e+00
Epoch 20/20
118/118 ━━━━━━━━━━━━━━━━━━━━ 14s 120ms/step - d_loss: -2.7175 - g_loss: -6.5900 - loss: 0.0000e+00
<keras.src.callbacks.history.History at 0x7fc763a8e950>
Display the last generated images:
from IPython.display import Image, display
display(Image("generated_img_0_19.png"))
display(Image("generated_img_1_19.png"))
display(Image("generated_img_2_19.png"))