条件GAN

作者: Sayak Paul
创建日期: 2021/07/13
最后修改: 2024/01/02
描述: 训练一个基于类别标签的GAN,以生成手写数字。

在Colab中查看 GitHub源代码

生成对抗网络(GAN)使我们能够从随机输入生成新颖的图像数据、视频数据或音频数据。通常,随机输入是从正态分布中采样的,然后经过一系列变换将其转换为某种合理的形式(图像、视频、音频等)。

然而,一个简单的DCGAN并不能让我们控制生成样本的外观(例如,类别)。例如,对于生成MNIST手写数字的GAN,一个简单的DCGAN无法让我们选择生成的数字类别。为了能够控制我们生成的内容,我们需要对GAN的输出进行_条件化_,基于语义输入,例如图像的类别。

在这个示例中,我们将构建一个条件GAN,能够根据给定的类别生成MNIST手写数字。这样的模型可以有各种有用的应用:

  • 假设您正在处理一个 不平衡的图像数据集, 并希望为偏斜类别收集更多示例以平衡数据集。数据收集本身可能是一个昂贵的过程。您可以训练一个条件GAN,并利用它为需要平衡的类别生成新颖的图像。
  • 由于生成器学习将生成样本与类别标签关联起来,它的表示也可以用于其他下游任务

以下是开发此示例所使用的参考文献:

如果您需要复习一下GAN,可以参考 这个资源中的“生成对抗网络”部分。

这个示例需要TensorFlow 2.5或更高版本,以及TensorFlow文档,可以使用以下命令进行安装:

!pip install -q git+https://github.com/tensorflow/docs

导入

import keras

from keras import layers
from keras import ops
from tensorflow_docs.vis import embed
import tensorflow as tf
import numpy as np
import imageio

常量和超参数

batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128

加载MNIST数据集并进行预处理

# 我们将使用训练集和测试集中所有可用的示例。
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# 将像素值缩放到[0, 1]范围,为图像添加一个通道维度,并对标签进行独热编码。
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = keras.utils.to_categorical(all_labels, 10)

# 创建tf.data.Dataset。
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

print(f"训练图像的形状: {all_digits.shape}")
print(f"训练标签的形状: {all_labels.shape}")
从 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 下载数据
 11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
训练图像的形状: (70000, 28, 28, 1)
训练标签的形状: (70000, 10)

计算生成器和判别器的输入通道数

在一个常规的(无条件)GAN中,我们首先从正态分布中抽样噪声(某个固定维度)。在我们的情况下,我们还需要考虑类别标签。我们需要将类别数添加到生成器(噪声输入)和判别器(生成图像输入)的输入通道中。

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)
138 11

创建鉴别器和生成器

模型定义(discriminatorgeneratorConditionalGAN)来自于这个示例

# 创建鉴别器。
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((28, 28, discriminator_in_channels)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)

# 创建生成器。
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        # 我们希望生成 128 + num_classes 个系数,以重塑为
        # 7x7x(128 + num_classes) 的映射。
        layers.Dense(7 * 7 * generator_in_channels),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Reshape((7, 7, generator_in_channels)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

创建 ConditionalGAN 模型

class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = keras.random.SeedGenerator(1337)
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # 解压数据。
        real_images, one_hot_labels = data

        # 向标签中添加虚拟维度,以便可以与
        # 图像连接。这是给鉴别器用的。
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = ops.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = ops.reshape(
            image_one_hot_labels, (-1, image_size, image_size, num_classes)
        )

        # 在潜在空间中随机采样点并连接标签。
        # 这是给生成器用的。
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # 解码噪声(根据标签引导)生成假图像。
        generated_images = self.generator(random_vector_labels)

        # 将它们与真实图像结合。注意这里我们将标签
        # 与这些图像连接。
        fake_image_and_labels = ops.concatenate(
            [generated_images, image_one_hot_labels], -1
        )
        real_image_and_labels = ops.concatenate([real_images, image_one_hot_labels], -1)
        combined_images = ops.concatenate(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # 汇总区分真实和假图像的标签。
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # 训练鉴别器。
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # 在潜在空间中随机采样点。
        random_latent_vectors = keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )
        random_vector_labels = ops.concatenate(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # 汇总表示“所有真实图像”的标签。
        misleading_labels = ops.zeros((batch_size, 1))

        # 训练生成器(注意我们不应该更新
        # 鉴别器的权重)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = ops.concatenate(
                [fake_images, image_one_hot_labels], -1
            )
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # 监控损失。
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

训练条件GAN

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=20)
第1轮/20
   18/1094 ━━━━━━━━━━━━━━━━━━━━  10s 9ms/step - d_loss: 0.6321 - g_loss: 0.7887 

WARNING: 所有日志消息在调用absl::InitializeLog()前都写入STDERR
I0000 00:00:1704233262.157522    6737 device_compiler.h:186] 使用XLA编译集群! 该行最多在进程的生命周期中记录一次。

 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 24s 14ms/step - d_loss: 0.4052 - g_loss: 1.5851 - discriminator_loss: 0.4390 - generator_loss: 1.4775
第2轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.5116 - g_loss: 1.2740 - discriminator_loss: 0.4872 - generator_loss: 1.3330
第3轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.3626 - g_loss: 1.6775 - discriminator_loss: 0.3252 - generator_loss: 1.8219
第4轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.2248 - g_loss: 2.2898 - discriminator_loss: 0.3418 - generator_loss: 2.0042
第5轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6017 - g_loss: 1.0428 - discriminator_loss: 0.6076 - generator_loss: 1.0176
第6轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6395 - g_loss: 0.9258 - discriminator_loss: 0.6448 - generator_loss: 0.9134
第7轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6402 - g_loss: 0.8914 - discriminator_loss: 0.6458 - generator_loss: 0.8773
第8轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6549 - g_loss: 0.8440 - discriminator_loss: 0.6555 - generator_loss: 0.8364
第9轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6603 - g_loss: 0.8316 - discriminator_loss: 0.6606 - generator_loss: 0.8241
第10轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6594 - g_loss: 0.8169 - discriminator_loss: 0.6605 - generator_loss: 0.8218
第11轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6719 - g_loss: 0.7979 - discriminator_loss: 0.6649 - generator_loss: 0.8096
第12轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6641 - g_loss: 0.7992 - discriminator_loss: 0.6621 - generator_loss: 0.7953
第13轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6657 - g_loss: 0.7979 - discriminator_loss: 0.6624 - generator_loss: 0.7924
第14轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6586 - g_loss: 0.8220 - discriminator_loss: 0.6566 - generator_loss: 0.8174
第15轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6646 - g_loss: 0.7916 - discriminator_loss: 0.6578 - generator_loss: 0.7973
第16轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6624 - g_loss: 0.7911 - discriminator_loss: 0.6587 - generator_loss: 0.7966
第17轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6586 - g_loss: 0.8060 - discriminator_loss: 0.6550 - generator_loss: 0.7997
第18轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6526 - g_loss: 0.7946 - discriminator_loss: 0.6523 - generator_loss: 0.7948
第19轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6525 - g_loss: 0.8039 - discriminator_loss: 0.6497 - generator_loss: 0.8066
第20轮/20
 1094/1094 ━━━━━━━━━━━━━━━━━━━━ 10s 9ms/step - d_loss: 0.6480 - g_loss: 0.8005 - discriminator_loss: 0.6469 - generator_loss: 0.8022

<keras.src.callbacks.history.History at 0x7f541a1b5f90>

使用训练好的生成器进行类之间插值

# 首先从我们的条件GAN中提取训练好的生成器。
trained_gen = cond_gan.generator

# 选择在插值过程中生成的中间图像数量 + 2(起始图像和最后图像)。
num_interpolation = 9  # @param {type:"integer"}

# 为插值样本生成噪声。
interpolation_noise = keras.random.normal(shape=(1, latent_dim))
interpolation_noise = ops.repeat(interpolation_noise, repeats=num_interpolation)
interpolation_noise = ops.reshape(interpolation_noise, (num_interpolation, latent_dim))


def interpolate_class(first_number, second_number):
    # 将起始和结束标签转换为独热编码向量。
    first_label = keras.utils.to_categorical([first_number], num_classes)
    second_label = keras.utils.to_categorical([second_number], num_classes)
    first_label = ops.cast(first_label, "float32")
    second_label = ops.cast(second_label, "float32")

    # 计算两个标签之间的插值向量。
    percent_second_label = ops.linspace(0, 1, num_interpolation)[:, None]
    percent_second_label = ops.cast(percent_second_label, "float32")
    interpolation_labels = (
        first_label * (1 - percent_second_label) + second_label * percent_second_label
    )

    # 将噪声和标签结合,并使用生成器进行推理。
    noise_and_labels = ops.concatenate([interpolation_noise, interpolation_labels], 1)
    fake = trained_gen.predict(noise_and_labels)
    return fake


start_class = 2  # @param {type:"slider", min:0, max:9, step:1}
end_class = 6  # @param {type:"slider", min:0, max:9, step:1}

fake_images = interpolate_class(start_class, end_class)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 427ms/step

在这里,我们首先从正态分布中采样噪声,然后重复进行num_interpolation次,并相应地重塑结果。 然后,我们以num_interpolation均匀分布,标签身份以某种比例出现。

fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = ops.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images[:, :, :, 0], fps=1)
embed.embed_file("animation.gif")

我们可以通过像WGAN-GP这样的方案进一步提高该模型的性能。条件生成在许多现代图像生成架构中也被广泛使用,如VQ-GANsDALL-E等。

您可以使用托管在Hugging Face Hub的训练模型,并在Hugging Face Spaces上尝试演示。