作者: fchollet
创建日期: 2019/04/29
最后修改日期: 2023/12/21
描述: 通过重写 `train_step` 并使用 `fit()`,在 CelebA 图像上训练的简单 DCGAN。
import keras
import tensorflow as tf
from keras import layers
from keras import ops
import matplotlib.pyplot as plt
import os
import gdown
from zipfile import ZipFile
我们将使用来自 CelebA 数据集的面部图像,并将其调整为 64x64。
# Create the target directory. exist_ok=True lets this script be re-run
# without raising FileExistsError when the folder is already present.
os.makedirs("celeba_gan", exist_ok=True)

url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
output = "celeba_gan/data.zip"
gdown.download(url, output, quiet=True)

# Unpack the downloaded archive into the same folder.
with ZipFile(output, "r") as zipobj:
    zipobj.extractall("celeba_gan")
从我们的文件夹创建数据集,并将图像重缩放到 [0-1] 范围:
# Build an unlabeled image dataset from the extracted folder, resizing every
# image to 64x64 and batching by 32.
dataset = keras.utils.image_dataset_from_directory(
    "celeba_gan",
    label_mode=None,
    image_size=(64, 64),
    batch_size=32,
)

# Rescale pixel values by dividing each batch by 255.
dataset = dataset.map(lambda images: images / 255.0)
找到 202599 个文件。
让我们显示一张样本图像:
# Show the first image of the first batch, then stop iterating.
for batch in dataset:
    # Undo the 1/255 rescaling so the image can be shown as integer pixels.
    first_image = (batch.numpy() * 255).astype("int32")[0]
    plt.axis("off")
    plt.imshow(first_image)
    break
判别器将 64x64 图像映射到一个二元分类分数。
# Discriminator: three stride-2 convolutions, each followed by a LeakyReLU,
# then flatten -> dropout -> a single sigmoid unit.
# Layers are created in the same order as they appear in the model so that
# Keras' auto-generated layer names stay unchanged.
discriminator_layers = [keras.Input(shape=(64, 64, 3))]
for num_filters in (64, 128, 128):
    discriminator_layers.append(
        layers.Conv2D(num_filters, kernel_size=4, strides=2, padding="same")
    )
    discriminator_layers.append(layers.LeakyReLU(negative_slope=0.2))
discriminator_layers.extend(
    [
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ]
)

discriminator = keras.Sequential(discriminator_layers, name="discriminator")
discriminator.summary()
模型: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 32, 32, 64) │ 3,136 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu (LeakyReLU) │ (None, 32, 32, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_1 (Conv2D) │ (None, 16, 16, 128) │ 131,200 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_1 (LeakyReLU) │ (None, 16, 16, 128) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_2 (Conv2D) │ (None, 8, 8, 128) │ 262,272 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_2 (LeakyReLU) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten (Flatten) │ (None, 8192) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (Dropout) │ (None, 8192) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (Dense) │ (None, 1) │ 8,193 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
总参数量: 404,801 (1.54 MB)
可训练参数: 404,801 (1.54 MB)
非可训练参数: 0 (0.00 B)
生成器与判别器互为镜像,将 `Conv2D` 层替换为 `Conv2DTranspose` 层。
# Size of the latent vector fed to the generator.
latent_dim = 128

# Generator: project the latent vector to an 8x8x128 feature map, upsample it
# with three stride-2 transposed convolutions (each followed by a LeakyReLU),
# and finish with a 3-channel sigmoid convolution.
# Layers are created in model order so auto-generated names stay unchanged.
generator_layers = [
    keras.Input(shape=(latent_dim,)),
    layers.Dense(8 * 8 * 128),
    layers.Reshape((8, 8, 128)),
]
for num_filters in (128, 256, 512):
    generator_layers.append(
        layers.Conv2DTranspose(
            num_filters, kernel_size=4, strides=2, padding="same"
        )
    )
    generator_layers.append(layers.LeakyReLU(negative_slope=0.2))
generator_layers.append(
    layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")
)

generator = keras.Sequential(generator_layers, name="generator")
generator.summary()
模型: "generator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ dense_1 (Dense) │ (None, 8192) │ 1,056,768 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ reshape (Reshape) │ (None, 8, 8, 128) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose │ (None, 16, 16, 128) │ 262,272 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_3 (LeakyReLU) │ (None, 16, 16, 128) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose_1 │ (None, 32, 32, 256) │ 524,544 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_4 (LeakyReLU) │ (None, 32, 32, 256) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose_2 │ (None, 64, 64, 512) │ 2,097,664 │ │ (Conv2DTranspose) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_5 (LeakyReLU) │ (None, 64, 64, 512) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_3 (Conv2D) │ (None, 64, 64, 3) │ 38,403 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
总参数: 3,979,651 (15.18 MB)
可训练参数: 3,979,651 (15.18 MB)
不可训练参数: 0 (0.00 B)
重写 `train_step`
class GAN(keras.Model):
    """GAN model that trains a generator against a discriminator.

    Each call to `train_step` first updates the discriminator on a batch of
    generated plus real images, then updates the generator so that its
    outputs are scored as real by the discriminator.

    Args:
        discriminator: Model scoring a batch of images.
        generator: Model mapping latent vectors to images.
        latent_dim: Length of the latent vectors sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # Single seeded source for every random draw made in `train_step`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the loss, and create loss trackers.

        Args:
            d_optimizer: Optimizer applied to the discriminator weights.
            g_optimizer: Optimizer applied to the generator weights.
            loss_fn: Callable `loss_fn(labels, predictions)` used for both
                the discriminator and the generator loss.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Return the discriminator and generator loss trackers."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: Batch of real images; its first dimension is used as
                the batch size.

        Returns:
            Dict with the running means of `d_loss` and `g_loss`.
        """
        # Sample random points in the latent space.
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images.
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images (generated first, real second).
        combined_images = ops.concatenate([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images:
        # 1 for generated images, 0 for real images.
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        # The noise is drawn from the model's seeded generator, like the
        # latent vectors, so that all randomness in this step is governed by
        # the same seed.
        labels += 0.05 * keras.random.uniform(
            ops.shape(labels), seed=self.seed_generator
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images" (0 is the "real" label).
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator): gradients are taken only with respect to the
        # generator's weights.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics.
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves generated images to disk at the end of each epoch.

    Args:
        num_img: Number of images to generate and save per epoch.
        latent_dim: Length of the latent vectors fed to the generator.
    """

    def __init__(self, num_img=3, latent_dim=128):
        # Initialize the base Callback so its own attributes are set up.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(42)

    def on_epoch_end(self, epoch, logs=None):
        """Generate `num_img` images and save each one as a PNG file.

        Files are written to the current working directory as
        `generated_img_<epoch>_<index>.png`, with the epoch zero-padded to
        three digits.
        """
        random_latent_vectors = keras.random.normal(
            shape=(self.num_img, self.latent_dim), seed=self.seed_generator
        )
        generated_images = self.model.generator(random_latent_vectors)
        generated_images *= 255
        # Keep the converted array; previously the result of `.numpy()` was
        # discarded, making the call a no-op.
        generated_images = generated_images.numpy()
        for i in range(self.num_img):
            img = keras.utils.array_to_img(generated_images[i])
            img.save(f"generated_img_{epoch:03d}_{i}.png")
epochs = 1  # In practice, use ~100 epochs

# Wrap the two networks in the GAN model and give each its own Adam optimizer.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# Save 10 generated images at the end of every epoch while training.
monitor = GANMonitor(num_img=10, latent_dim=latent_dim)
gan.fit(dataset, epochs=epochs, callbacks=[monitor])
2/6332 ━━━━━━━━━━━━━━━━━━━━ 9:54 94ms/step - d_loss: 0.6792 - g_loss: 0.7880
警告:调用 absl::InitializeLog() 之前的所有日志消息都写入 STDERR
I0000 00:00:1704214667.959762 1319 device_compiler.h:186] 使用 XLA 编译的集群!该行在进程生命周期内最多记录一次。
6332/6332 ━━━━━━━━━━━━━━━━━━━━ 557s 84ms/step - d_loss: 0.5616 - g_loss: 1.4099
<keras.src.callbacks.history.History at 0x7f251d32bc40>
以下是第 30 轮左右最后生成的部分图像(此后结果仍会持续改善):