代码示例 / 生成式深度学习 / 数据高效的 GAN 结合自适应鉴别器增强

数据高效的 GAN 结合自适应鉴别器增强

作者: András Béres
创建日期: 2021/10/28
最后修改: 2021/10/28
描述: 使用 Caltech Birds 数据集从有限数据中生成图像。

在 Colab 中查看 GitHub 源码


介绍

GANs

生成对抗网络 (GANs) 是一种流行的生成深度学习模型,常用于图像生成。它们由一对对抗的神经网络组成,称为鉴别器和生成器。鉴别器的任务是区分真实图像和生成(虚假)图像,而生成器则试图通过生成越来越真实的图像来欺骗鉴别器。然而,如果生成器太容易或太难以欺骗,可能会导致无法为生成器提供有用的学习信号,因此训练 GAN 通常被认为是一项困难的任务。

GANs 的数据增强

数据增强是一种在深度学习中流行的技术,是随机应用保持语义的转换到输入数据的过程,以生成多个现实版本,从而有效地增加可用的训练数据量。最简单的例子是左右翻转图像,这样可以在保留内容的同时生成第二个独特的训练样本。数据增强通常在监督学习中使用,以防止过拟合并增强泛化能力。

StyleGAN2-ADA 的作者表明,在 GAN 中,鉴别器过拟合可能是一个问题,特别是在只有少量训练数据可用时。他们提出了自适应鉴别器增强以缓解这个问题。

然而,将数据增强应用于 GAN 并不是简单的事情。由于生成器是使用鉴别器的梯度更新的,因此如果生成的图像被增强,增强管道必须是可微分的,并且必须与 GPU 兼容以提高计算效率。幸运的是,Keras 图像增强层 满足这两个要求,因此非常适合这个任务。

可逆的数据增强

在生成模型中使用数据增强时,可能的困难是关于"泄漏增强"(第 2.2 节)的问题,即模型生成的图像已经被增强。这将意味着它无法将增强与基础数据分布分开,这可能是由于使用了不可逆的数据转换造成的。例如,如果以相等的概率进行 0、90、180 或 270 度的旋转,则无法推断图像的原始方向,并且这一信息被破坏。

使数据增强可逆的一个简单技巧是仅在一定概率下应用它们。这样,图像的原始版本会更常见,并且数据分布可以被推断。通过适当地选择这个概率,可以有效地正则化鉴别器,而不使增强泄漏。


设置

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from tensorflow.keras import layers

超参数

# 数据
num_epochs = 10  # 训练 400 个周期以获得良好的结果
image_size = 64
# Kernel Inception Distance 测量的分辨率,请参见相关部分
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"

# 自适应鉴别器增强
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000

# 架构
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4

# 优化
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5  # 不使用默认值 0.9 是重要的
ema = 0.99

数据管道

在本例中,我们将使用 Caltech Birds (2011) 数据集生成鸟类图像,这是一个多样的自然数据集,包含不到 6000 张用于训练的图像。在处理如此少量的数据时,必须特别小心,以尽可能保持高数据质量。在本例中,我们使用提供的鸟类边界框,在尽可能保留其宽高比的情况下,对其进行方形裁剪。

def round_to_int(float_value):
    return tf.cast(tf.math.round(float_value), dtype=tf.int32)


def preprocess_image(data):
    # 反归一化边界框坐标
    height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
    width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
    bounding_box = data["bbox"] * tf.stack([height, width, height, width])

    # 计算中心和较长边的长度,添加填充
    target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
    target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
    target_size = tf.maximum(
        (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
        (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
    )

    # 修改裁剪大小以适应图像
    target_height = tf.reduce_min(
        [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
    )
    target_width = tf.reduce_min(
        [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
    )

    # 裁剪图像
    image = tf.image.crop_to_bounding_box(
        data["image"],
        offset_height=round_to_int(target_center_y - 0.5 * target_height),
        offset_width=round_to_int(target_center_x - 0.5 * target_width),
        target_height=round_to_int(target_height),
        target_width=round_to_int(target_width),
    )

    # 调整大小并剪裁
    # 对于图像下采样,面积插值是首选方法
    image = tf.image.resize(
        image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
    )
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # 验证数据集也会被打乱,因为数据顺序很重要
    # 对于KID计算
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")

在预处理后,训练图像看起来如下: birds dataset


核心Inception距离

核心Inception距离 (KID) 被提出作为一种 替代流行的 Frechet Inception距离 (FID) 的度量,用于衡量图像生成质量。 这两个度量在 InceptionV3 网络的表示空间中,衡量生成分布与训练分布的差异,网络是在 ImageNet进行预训练的。

根据论文,KID被提出是因为FID没有无偏估计器,即在较少的图像上测量时,它的期望值更高。KID更适合小型数据集,因为它的期望值不依赖于测量样本的数量。根据我的经验,它在计算上也更轻,数值上更稳定,实现起来也更简单,因为它可以按批次进行估计。

在这个例子中,图像在Inception网络的最小可能分辨率下进行评估(75x75而不是299x299),并且该度量仅在验证集上测量,以提高计算效率。

class KID(keras.metrics.Metric):
    def __init__(self, name="kid", **kwargs):
        super().__init__(name=name, **kwargs)

        # KID 是按批次估计的,并在批次之间平均
        self.kid_tracker = keras.metrics.Mean()

        # 使用预训练的 InceptionV3,省略其分类层
        # 将像素值转换到 0-255 范围,然后使用与预训练期间相同的
        # 预处理方法
        self.encoder = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
        return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # 使用两组特征计算多项式核
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # 使用平均核值估计平方最大均值差异
        batch_size = tf.shape(real_features)[0]
        batch_size_f = tf.cast(batch_size, dtype=tf.float32)
        mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = tf.reduce_sum(
            kernel_generated * (1.0 - tf.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = tf.reduce_mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # 更新平均 KID 估计
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()

自适应判别器增强

StyleGAN2-ADA的作者提出在训练过程中自适应地改变增强概率。尽管在论文中有不同的解释,他们使用积分控制来保持判别器在真实图像上的准确率接近目标值。请注意,他们控制的变量实际上是判别器logits的平均符号(论文中的r_t),这对应于2 * 准确率 - 1。

这种方法需要两个超参数:

  1. target_accuracy: 判别器在真实图像上的准确率目标值。我建议选择80-90%的范围。
  2. integration_steps: 将100%的准确率误差转化为100%的增强概率增加所需的更新步数。为了给出直观的感受,这定义了增强概率变化的速度。我建议将其设置为相对较高的值(在本例中为1000),以便增强强度仅能慢慢调整。

这种程序的主要动机是目标准确率的最佳值在不同数据集大小之间是类似的(见论文中的图4和图5),因此不需要重新调整,因为这个过程可以自动在需要时施加更强的数据增强。

# "硬 sigmoid",对从logits计算二元准确率有用
def step(values):
    # 负值 -> 0.0,正值 -> 1.0
    return 0.5 * (1.0 + tf.sign(values))


# 以在训练过程中动态更新的概率增强图像
class AdaptiveAugmenter(keras.Model):
    def __init__(self):
        super().__init__()

        # 存储图像被增强的当前概率
        self.probability = tf.Variable(0.0)

        # 论文中每一层上方所示的相应增强名称
        # 作者展示(见图4),在低数据情况下,blitting和几何增强最为有效
        self.augmenter = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                # blitting/x-flip:
                layers.RandomFlip("horizontal"),
                # blitting/整数平移:
                layers.RandomTranslation(
                    height_factor=max_translation,
                    width_factor=max_translation,
                    interpolation="nearest",
                ),
                # geometric/旋转:
                layers.RandomRotation(factor=max_rotation),
                # geometric/等比和非等比缩放:
                layers.RandomZoom(
                    height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
                ),
            ],
            name="adaptive_augmenter",
        )

    def call(self, images, training):
        if training:
            augmented_images = self.augmenter(images, training)

            # 在训练期间,基于self.probability选择原始图像或增强图像
            augmentation_values = tf.random.uniform(
                shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
            )
            augmentation_bools = tf.math.less(augmentation_values, self.probability)

            images = tf.where(augmentation_bools, augmented_images, images)
        return images

    def update(self, real_logits):
        current_accuracy = tf.reduce_mean(step(real_logits))

        # 增强概率基于判别器在真实图像上的准确率进行更新
        accuracy_error = current_accuracy - target_accuracy
        self.probability.assign(
            tf.clip_by_value(
                self.probability + accuracy_error / integration_steps, 0.0, 1.0
            )
        )

网络架构

在这里我们指定两个网络的架构:

  • 生成器:将随机向量映射到一张图像,尽可能真实
  • 判别器:将图像映射到一个标量分数,对于真实图像应该较高,对于生成图像应该较低

GAN对网络架构敏感,我在这个例子中实现了DCGAN架构,因为它在训练过程中相对稳定,同时实现较简单。我们在整个网络中使用恒定数量的滤波器,在生成器的最后一层使用sigmoid而不是tanh,并使用默认初始化而不是随机正态作为进一步的简化。

作为一种良好的实践,我们在批量归一化层中禁用了可学习的缩放参数,因为一方面后续的relu + 卷积层使其多余(如前所述)。 文档)。但是也因为根据理论,在使用谱归一化(第4.1节)时,它应该被禁用,这在这里未使用,但在GAN中很常见。我们还在全连接层和卷积层中禁用偏置,因为后续的批归一化使其变得多余。

# DCGAN 生成器
def get_generator():
    noise_input = keras.Input(shape=(noise_size,))
    x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
    x = layers.BatchNormalization(scale=False)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape(target_shape=(4, 4, width))(x)
    for _ in range(depth - 1):
        x = layers.Conv2DTranspose(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.ReLU()(x)
    image_output = layers.Conv2DTranspose(
        3, kernel_size=4, strides=2, padding="same", activation="sigmoid",
    )(x)

    return keras.Model(noise_input, image_output, name="generator")


# DCGAN 判别器
def get_discriminator():
    image_input = keras.Input(shape=(image_size, image_size, 3))
    x = image_input
    for _ in range(depth):
        x = layers.Conv2D(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    output_score = layers.Dense(1)(x)

    return keras.Model(image_input, output_score, name="discriminator")

GAN 模型

class GAN_ADA(keras.Model):
    def __init__(self):
        super().__init__()

        self.augmenter = AdaptiveAugmenter()
        self.generator = get_generator()
        self.ema_generator = keras.models.clone_model(self.generator)
        self.discriminator = get_discriminator()

        self.generator.summary()
        self.discriminator.summary()

    def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
        super().compile(**kwargs)

        # 两个网络的优化器分开
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
        self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
        self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
        self.kid = KID()

    @property
    def metrics(self):
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
            self.real_accuracy,
            self.generated_accuracy,
            self.augmentation_probability_tracker,
            self.kid,
        ]

    def generate(self, batch_size, training):
        latent_samples = tf.random.normal(shape=(batch_size, noise_size))
        # 在推理时使用 ema_generator
        if training:
            generated_images = self.generator(latent_samples, training)
        else:
            generated_images = self.ema_generator(latent_samples, training)
        return generated_images

    def adversarial_loss(self, real_logits, generated_logits):
        # 这通常被称为非饱和GAN损失

        real_labels = tf.ones(shape=(batch_size, 1))
        generated_labels = tf.zeros(shape=(batch_size, 1))

        # 生成器试图生成被判别器视为真实的图像
        generator_loss = keras.losses.binary_crossentropy(
            real_labels, generated_logits, from_logits=True
        )
        # 判别器试图区分图像是真实的还是生成的
        discriminator_loss = keras.losses.binary_crossentropy(
            tf.concat([real_labels, generated_labels], axis=0),
            tf.concat([real_logits, generated_logits], axis=0),
            from_logits=True,
        )

        return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)

    def train_step(self, real_images):
        real_images = self.augmenter(real_images, training=True)

        # 使用持久的梯度带,因为梯度将被计算两次
        with tf.GradientTape(persistent=True) as tape:
            generated_images = self.generate(batch_size, training=True)
            # 通过图像增强计算梯度
            generated_images = self.augmenter(generated_images, training=True)

            # 对真实图像和生成图像进行分开前向传播,这意味着
            # 批归一化是单独应用的
            real_logits = self.discriminator(real_images, training=True)
            generated_logits = self.discriminator(generated_images, training=True)

            generator_loss, discriminator_loss = self.adversarial_loss(
                real_logits, generated_logits
            )

        # 计算梯度并更新权重
        generator_gradients = tape.gradient(
            generator_loss, self.generator.trainable_weights
        )
        discriminator_gradients = tape.gradient(
            discriminator_loss, self.discriminator.trainable_weights
        )
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_weights)
        )

        # 基于判别器的性能更新增强概率
        self.augmenter.update(real_logits)

        self.generator_loss_tracker.update_state(generator_loss)
        self.discriminator_loss_tracker.update_state(discriminator_loss)
        self.real_accuracy.update_state(1.0, step(real_logits))
        self.generated_accuracy.update_state(0.0, step(generated_logits))
        self.augmentation_probability_tracker.update_state(self.augmenter.probability)

        # 跟踪生成器权重的指数移动平均,以减少
        # 生成质量中的方差
        for weight, ema_weight in zip(
            self.generator.weights, self.ema_generator.weights
        ):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID 在训练阶段不进行测量以提高计算效率
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, real_images):
        generated_images = self.generate(batch_size, training=False)

        self.kid.update_state(real_images, generated_images)

        # 只有 KID 在评估阶段被测量以提高计算效率
        return {self.kid.name: self.kid.result()}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
        # 绘制随机生成的图像以视觉评估生成质量
        if epoch is None or (epoch + 1) % interval == 0:
            num_images = num_rows * num_cols
            generated_images = self.generate(num_images, training=False)

            plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
            for row in range(num_rows):
                for col in range(num_cols):
                    index = row * num_cols + col
                    plt.subplot(num_rows, num_cols, index + 1)
                    plt.imshow(generated_images[index])
                    plt.axis("off")
            plt.tight_layout()
            plt.show()
            plt.close()

训练

可以从训练过程中的指标看出,如果真实准确率(对真实图像的判别器准确率)低于目标准确率,增强概率会增加,反之亦然。根据我的经验,在正常的GAN训练过程中,判别器的准确率应保持在80-95%的范围内。低于这个范围,判别器过弱;高于这个范围,判别器过强。

请注意,我们跟踪生成器权重的指数移动平均值,并在图像生成和KID评估中使用它。

# 创建和编译模型
model = GAN_ADA()
model.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)

# 根据验证KID指标保存最佳模型
checkpoint_path = "gan_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# 运行训练并定期绘制生成的图像
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)
模型: "generator"
_________________________________________________________________
层 (类型)                  输出形状                   参数 #   
=================================================================
input_2 (输入层)           [(None, 64)]               0         
_________________________________________________________________
dense (全连接)             (None, 2048)               131072    
_________________________________________________________________
batch_normalization (批正  (None, 2048)               6144      
_________________________________________________________________
re_lu (ReLU)               (None, 2048)               0         
_________________________________________________________________
reshape (重塑)             (None, 4, 4, 128)          0         
_________________________________________________________________
conv2d_transpose (反卷积) (None, 8, 8, 128)           262144    
_________________________________________________________________
batch_normalization_1 (批正 (None, 8, 8, 128)          384       
_________________________________________________________________
re_lu_1 (ReLU)             (None, 8, 8, 128)          0         
_________________________________________________________________
conv2d_transpose_1 (反卷积)(None, 16, 16, 128)        262144    
_________________________________________________________________
batch_normalization_2 (批正 (None, 16, 16, 128)        384       
_________________________________________________________________
re_lu_2 (ReLU)             (None, 16, 16, 128)        0         
_________________________________________________________________
conv2d_transpose_2 (反卷积)(None, 32, 32, 128)        262144    
_________________________________________________________________
batch_normalization_3 (批正 (None, 32, 32, 128)        384       
_________________________________________________________________
re_lu_3 (ReLU)             (None, 32, 32, 128)        0         
_________________________________________________________________
conv2d_transpose_3 (反卷积)(None, 64, 64, 3)          6147      
=================================================================
总参数: 930,947
可训练参数: 926,083
不可训练参数: 4,864
_________________________________________________________________
模型: "discriminator"
_________________________________________________________________
层 (类型)                  输出形状                   参数 #   
=================================================================
input_3 (输入层)           [(None, 64, 64, 3)]        0         
_________________________________________________________________
conv2d (卷积)              (None, 32, 32, 128)        6144      
_________________________________________________________________
batch_normalization_4 (批正 (None, 32, 32, 128)        384       
_________________________________________________________________
leaky_re_lu (泄漏ReLU)      (None, 32, 32, 128)        0         
_________________________________________________________________
conv2d_1 (卷积)            (None, 16, 16, 128)        262144    
_________________________________________________________________
batch_normalization_5 (批正 (None, 16, 16, 128)        384       
_________________________________________________________________
leaky_re_lu_1 (泄漏ReLU)    (None, 16, 16, 128)        0         
_________________________________________________________________
conv2d_2 (卷积)            (None, 8, 8, 128)          262144    
_________________________________________________________________
batch_normalization_6 (批正 (None, 8, 8, 128)          384       
_________________________________________________________________
leaky_re_lu_2 (泄漏ReLU)    (None, 8, 8, 128)          0         
_________________________________________________________________
conv2d_3 (卷积)            (None, 4, 4, 128)          262144    
_________________________________________________________________
batch_normalization_7 (批正 (None, 4, 4, 128)          384       
_________________________________________________________________
leaky_re_lu_3 (泄漏ReLU)    (None, 4, 4, 128)          0         
_________________________________________________________________
flatten (展平)             (None, 2048)               0         
_________________________________________________________________
dropout (丢弃)             (None, 2048)               0         
_________________________________________________________________
dense_1 (全连接)           (None, 1)                  2049      
=================================================================
总参数: 796,161
可训练参数: 795,137
不可训练参数: 1,024
_________________________________________________________________
第 1 轮/10
46/46 [==============================] - 36s 307ms/步 - g_loss: 3.3293 - d_loss: 0.1576 - real_acc: 0.9387 - gen_acc: 0.9579 - aug_p: 0.0020 - val_kid: 9.0999
第 2 轮/10
46/46 [==============================] - 10s 215ms/步 - g_loss: 4.9824 - d_loss: 0.0912 - real_acc: 0.9704 - gen_acc: 0.9798 - aug_p: 0.0077 - val_kid: 8.3523
第 3 轮/10
46/46 [==============================] - 10s 218ms/步 - g_loss: 5.0587 - d_loss: 0.1248 - real_acc: 0.9530 - gen_acc: 0.9625 - aug_p: 0.0131 - val_kid: 6.8116
第 4 轮/10
46/46 [==============================] - 10s 221ms/步 - g_loss: 4.2580 - d_loss: 0.1002 - real_acc: 0.9686 - gen_acc: 0.9740 - aug_p: 0.0179 - val_kid: 5.2327
第 5 轮/10
46/46 [==============================] - 10s 225ms/步 - g_loss: 4.6022 - d_loss: 0.0847 - real_acc: 0.9655 - gen_acc: 0.9852 - aug_p: 0.0234 - val_kid: 3.9004

png

第6轮/共10轮
46/46 [==============================] - 10s 224ms/step - g_loss: 4.9362 - d_loss: 0.0671 - real_acc: 0.9791 - gen_acc: 0.9895 - aug_p: 0.0291 - val_kid: 6.6020
第7轮/共10轮
46/46 [==============================] - 10s 222ms/step - g_loss: 4.4272 - d_loss: 0.1184 - real_acc: 0.9570 - gen_acc: 0.9657 - aug_p: 0.0345 - val_kid: 3.3644
第8轮/共10轮
46/46 [==============================] - 10s 220ms/step - g_loss: 4.5060 - d_loss: 0.1635 - real_acc: 0.9421 - gen_acc: 0.9594 - aug_p: 0.0392 - val_kid: 3.1381
第9轮/共10轮
46/46 [==============================] - 10s 219ms/step - g_loss: 3.8264 - d_loss: 0.1667 - real_acc: 0.9383 - gen_acc: 0.9484 - aug_p: 0.0433 - val_kid: 2.9423
第10轮/共10轮
46/46 [==============================] - 10s 219ms/step - g_loss: 3.4063 - d_loss: 0.1757 - real_acc: 0.9314 - gen_acc: 0.9475 - aug_p: 0.0473 - val_kid: 2.9112

png

<keras.callbacks.History 在 0x7fefcc2cb9d0>

推理

# 加载最佳模型并生成图像
model.load_weights(checkpoint_path)
model.plot_images()

png


结果

通过训练400个轮次(在Colab笔记本中通常需要2-3小时),可以使用此代码示例获得高质量的图像生成。

在400个轮次训练中随机批次图像的演变(ema=0.999以获得动画平滑度): birds evolution gif

在一批选定图像之间的潜在空间插值: birds interpolation gif

我还建议尝试在其他数据集上进行训练,例如CelebA。根据我的经验,在不更改任何超参数的情况下也可以获得良好的结果(尽管可能不需要判别器增强)。


GAN小贴士和技巧

我在这个示例中的目标是找到实现简单性和生成质量之间的良好折衷。在准备过程中,我使用this repository进行了许多消融实验。

在本节中,我列出了所学到的经验教训以及我的建议,按主观重要性排序。

我建议查看DCGAN论文、这个NeurIPS演讲以及这个大规模GAN研究,了解他人在这一主题上的看法。

架构建议

  • 分辨率:在更高分辨率下训练GAN会更困难,建议最初在32x32或64x64分辨率下进行实验。
  • 初始化:如果您在训练早期看到强烈的彩色模式,初始化可能是问题所在。将层的kernel_initializer参数设置为random normal,并将标准差减小(推荐值:0.02,遵循DCGAN)直到问题消失。
  • 上采样:生成器中有两种主要的上采样方法。转置卷积更快,但可能导致棋盘伪影,可以通过使用可被步幅整除的内核大小来减少(推荐的内核大小是4,步幅为2)。上采样 + 标准卷积的质量可能稍低,但棋盘伪影不是问题。 我建议使用最近邻插值,而不是双线性插值。
  • 判别器中的批归一化:有时影响很大,我建议两种方式都尝试。
  • 谱归一化:一种流行的GAN训练技术,可以帮助稳定性。我建议与之一起禁用批归一化的可学习缩放参数。
  • 残差连接:尽管残差判别器的行为相似,但根据我的经验,残差生成器更难训练。然而,它们对于训练大型和深度架构是必要的。我建议从非残差架构开始。
  • 丢弃法:根据我的经验,在判别器的最后一层之前使用丢弃法可以改善生成质量。推荐的丢弃率低于0.5。
  • 泄漏ReLU:使用泄漏 ReLU激活函数在判别器中,以减小其梯度稀疏性。推荐的斜率/α为0.2,遵循DCGAN。

算法提示

  • 损失函数:多年来提出了许多损失函数用于训练GAN,承诺提高性能和稳定性。我在这个仓库中实现了其中5个,它们的表现与这篇GAN研究一致:没有一种损失函数能持续优于默认的非饱和GAN损失。我建议默认使用该损失函数。
  • Adam的beta_1参数:Adam中的beta_1参数可以被理解为平均梯度估计的动量。使用0.5甚至0.0代替默认的0.9值是DCGAN中提出的重要建议。使用默认值将无法获得此示例的效果。
  • 生成图像和真实图像的单独批量归一化:判别器的前向传播应对生成图像和真实图像分别进行。否则可能会导致伪影(在我的例子中为45度条纹)并降低性能。
  • 生成器权重的指数移动平均:这有助于降低KID测量的方差,并有助于在训练过程中平均快速的调色板变化。
  • 生成器和判别器不同的学习率:如果有资源,可以帮助单独调节两个网络的学习率。一个类似的想法是对一个网络(通常是判别器)的权重在另一个网络更新的每次迭代中进行多次更新。我建议对两个网络使用相同的学习率2e-4(Adam),遵循DCGAN,并作为默认情况下只更新它们一次。
  • 标签噪音单侧标签平滑(对真实标签使用小于1.0的值)或在标签中添加噪音可以正则化判别器,避免其过于自信;然而在我的案例中,这并没有提高性能。
  • 自适应数据增强:由于它为训练过程增加了另一个动态组件,默认情况下将其禁用,只有在其他组件已经运行良好时才启用。

相关工作

其他与GAN相关的Keras代码示例:

现代GAN架构线:

关于判别器数据增强的相关论文: 123

关于GAN的最新文献概述:讲座