作者: Sayak Paul
创建日期: 2021/07/13
最后修改: 2024/01/02
描述: 训练一个基于类别标签的GAN,以生成手写数字。
生成对抗网络(GAN)使我们能够从随机输入生成新颖的图像数据、视频数据或音频数据。通常,随机输入是从正态分布中采样的,然后经过一系列变换将其转换为某种合理的形式(图像、视频、音频等)。
然而,一个简单的DCGAN并不能让我们控制生成样本的外观(例如,类别)。例如,对于生成MNIST手写数字的GAN,一个简单的DCGAN无法让我们选择生成的数字类别。为了能够控制我们生成的内容,我们需要对GAN的输出进行_条件化_,基于语义输入,例如图像的类别。
在这个示例中,我们将构建一个条件GAN,能够根据给定的类别生成MNIST手写数字。这样的模型可以有各种有用的应用:
以下是开发此示例所使用的参考文献:
如果您需要复习一下GAN,可以参考 这个资源中的“生成对抗网络”部分。
这个示例需要TensorFlow 2.5或更高版本,以及TensorFlow文档,可以使用以下命令进行安装:
!pip install -q git+https://github.com/tensorflow/docs
import keras
from keras import layers
from keras import ops
from tensorflow_docs.vis import embed
import tensorflow as tf
import numpy as np
import imageio
batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128
# 我们将使用训练集和测试集中所有可用的示例。
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])
# 将像素值缩放到[0, 1]范围,为图像添加一个通道维度,并对标签进行独热编码。
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)
# 创建tf.data.Dataset。
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
print(f"训练图像的形状: {all_digits.shape}")
print(f"训练标签的形状: {all_labels.shape}")
从 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 下载数据
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
训练图像的形状: (70000, 28, 28, 1)
训练标签的形状: (70000, 10)
在一个常规的(无条件)GAN中,我们首先从正态分布中抽样噪声(某个固定维度)。在我们的情况下,我们还需要考虑类别标签。我们需要将类别数添加到生成器(噪声输入)和判别器(生成图像输入)的输入通道中。
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)
138 11
模型定义(discriminator
、generator
和ConditionalGAN
)来自于这个示例。
# 创建鉴别器。
discriminator = keras.Sequential(
[
keras.layers.InputLayer((28, 28, discriminator_in_channels)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# 创建生成器。
generator = keras.Sequential(
[
keras.layers.InputLayer((generator_in_channels,)),
# 我们希望生成 128 + num_classes 个系数,以重塑为
# 7x7x(128 + num_classes) 的映射。
layers.Dense(7 * 7 * generator_in_channels),
layers.LeakyReLU(negative_slope=0.2),
layers.Reshape((7, 7, generator_in_channels)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(negative_slope=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
ConditionalGAN
模型class ConditionalGAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super().__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
self.seed_generator = keras.random.SeedGenerator(1337)
self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
@property
def metrics(self):
return [self.gen_loss_tracker, self.disc_loss_tracker]
def compile(self, d_optimizer, g_optimizer, loss_fn):
super().compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, data):
# 解压数据。
real_images, one_hot_labels = data
# 向标签中添加虚拟维度,以便可以与
# 图像连接。这是给鉴别器用的。
image_one_hot_labels = one_hot_labels[:, :, None, None]
image_one_hot_labels = ops.repeat(
image_one_hot_labels, repeats=[image_size * image_size]
)
image_one_hot_labels = ops.reshape(
image_one_hot_labels, (-1, image_size, image_size, num_classes)
)
# 在潜在空间中随机采样点并连接标签。
# 这是给生成器用的。
batch_size = ops.shape(real_images)[0]
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
random_vector_labels = ops.concatenate(
[random_latent_vectors, one_hot_labels], axis=1
)
# 解码噪声(根据标签引导)生成假图像。
generated_images = self.generator(random_vector_labels)
# 将它们与真实图像结合。注意这里我们将标签
# 与这些图像连接。
fake_image_and_labels = ops.concatenate(
[generated_images, image_one_hot_labels], -1
)
real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
combined_images = ops.concatenate(
[fake_image_and_labels, real_image_and_labels], axis=0
)
# 汇总区分真实和假图像的标签。
labels = ops.concatenate(
[ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
)
# 训练鉴别器。
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# 在潜在空间中随机采样点。
random_latent_vectors = keras.random.normal(
shape=(batch_size, self.latent_dim), seed=self.seed_generator
)
random_vector_labels = ops.concatenate(
[random_latent_vectors, one_hot_labels], axis=1
)
# 汇总表示“所有真实图像”的标签。
misleading_labels = ops.zeros((batch_size, 1))
# 训练生成器(注意我们不应该更新
# 鉴别器的权重)!
with tf.GradientTape() as tape:
fake_images = self.generator(random_vector_labels)
fake_image_and_labels = ops.concatenate(
[fake_images, image_one_hot_labels], -1
)
predictions = self.discriminator(fake_image_and_labels)
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
# 监控损失。
self.gen_loss_tracker.update_state(g_loss)
self.disc_loss_tracker.update_state(d_loss)
return {
"g_loss": self.gen_loss_tracker.result(),
"d_loss": self.disc_loss_tracker.result(),
}
cond_gan = ConditionalGAN(
discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
cond_gan.fit(dataset, epochs=20)
第1轮/20
18/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6321 - g_loss: 0.7887
WARNING: 所有日志消息在调用absl::InitializeLog()前都写入STDERR
I0000 00:00:1704233262.157522 6737 device_compiler.h:186] 使用XLA编译集群! 该行最多在进程的生命周期中记录一次。
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 24s 14ms/step - d_loss: 0.4052 - g_loss: 1.5851 - discriminator_loss: 0.4390 - generator_loss: 1.4775
第2轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.5116 - g_loss: 1.2740 - discriminator_loss: 0.4872 - generator_loss: 1.3330
第3轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.3626 - g_loss: 1.6775 - discriminator_loss: 0.3252 - generator_loss: 1.8219
第4轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.2248 - g_loss: 2.2898 - discriminator_loss: 0.3418 - generator_loss: 2.0042
第5轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6017 - g_loss: 1.0428 - discriminator_loss: 0.6076 - generator_loss: 1.0176
第6轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6395 - g_loss: 0.9258 - discriminator_loss: 0.6448 - generator_loss: 0.9134
第7轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6402 - g_loss: 0.8914 - discriminator_loss: 0.6458 - generator_loss: 0.8773
第8轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6549 - g_loss: 0.8440 - discriminator_loss: 0.6555 - generator_loss: 0.8364
第9轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6603 - g_loss: 0.8316 - discriminator_loss: 0.6606 - generator_loss: 0.8241
第10轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6594 - g_loss: 0.8169 - discriminator_loss: 0.6605 - generator_loss: 0.8218
第11轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6719 - g_loss: 0.7979 - discriminator_loss: 0.6649 - generator_loss: 0.8096
第12轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6641 - g_loss: 0.7992 - discriminator_loss: 0.6621 - generator_loss: 0.7953
第13轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6657 - g_loss: 0.7979 - discriminator_loss: 0.6624 - generator_loss: 0.7924
第14轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6586 - g_loss: 0.8220 - discriminator_loss: 0.6566 - generator_loss: 0.8174
第15轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6646 - g_loss: 0.7916 - discriminator_loss: 0.6578 - generator_loss: 0.7973
第16轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6624 - g_loss: 0.7911 - discriminator_loss: 0.6587 - generator_loss: 0.7966
第17轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6586 - g_loss: 0.8060 - discriminator_loss: 0.6550 - generator_loss: 0.7997
第18轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6526 - g_loss: 0.7946 - discriminator_loss: 0.6523 - generator_loss: 0.7948
第19轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6525 - g_loss: 0.8039 - discriminator_loss: 0.6497 - generator_loss: 0.8066
第20轮/20
1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6480 - g_loss: 0.8005 - discriminator_loss: 0.6469 - generator_loss: 0.8022
<keras.src.callbacks.history.History at 0x7f541a1b5f90>
# 首先从我们的条件GAN中提取训练好的生成器。
trained_gen = cond_gan.generator
# 选择在插值过程中生成的中间图像数量 + 2(起始图像和最后图像)。
num_interpolation = 9 # @param {type:"integer"}
# 为插值样本生成噪声。
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))
def interpolate_class(first_number, second_number):
# 将起始和结束标签转换为独热编码向量。
first_label = keras.utils.to_categorical([first_number], num_classes)
second_label = keras.utils.to_categorical([second_number], num_classes)
first_label = ops.cast(first_label, "float32")
second_label = ops.cast(second_label, "float32")
# 计算两个标签之间的插值向量。
percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
percent_second_label = ops.cast(percent_second_label, "float32")
interpolation_labels = (
first_label * (1 - percent_second_label) + second_label * percent_second_label
)
# 将噪声和标签结合,并使用生成器进行推理。
noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
fake = trained_gen.predict(noise_and_labels)
return fake
start_class = 2 # @param {type:"slider", min:0, max:9, step:1}
end_class = 6 # @param {type:"slider", min:0, max:9, step:1}
fake_images = interpolate_class(start_class, end_class)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 427ms/step
在这里,我们首先从正态分布中采样噪声,然后重复进行num_interpolation
次,并相应地重塑结果。
然后,我们以num_interpolation
均匀分布,标签身份以某种比例出现。
fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
embed.embed_file("animation.gif")
我们可以通过像WGAN-GP这样的方案进一步提高该模型的性能。条件生成在许多现代图像生成架构中也被广泛使用,如VQ-GANs、DALL-E等。
您可以使用托管在Hugging Face Hub的训练模型,并在Hugging Face Spaces上尝试演示。