代码示例 / 生成式深度学习 / 向量量化变分自编码器

向量量化变分自编码器

作者: Sayak Paul
创建日期: 2021/07/21
最后修改: 2021/06/27

在Colab中查看 GitHub源代码

描述: 训练VQ-VAE以进行图像重建和代码本采样生成。

在本示例中,我们开发了一个向量量化变分自编码器(VQ-VAE)。 VQ-VAE是由van der Oord等人在神经离散表示学习中提出的。在标准的变分自编码器(VAE)中,潜在空间是连续的,并且是从高斯分布中抽样的。通过梯度下降学习这样的连续分布通常更困难。另一方面,VQ-VAE在离散潜在空间上操作,使优化问题更简单。它通过维护一个离散的代码本来实现这一点。代码本通过离散化连续嵌入与编码输出之间的距离来开发。这些离散的代码字然后被输入到解码器中,解码器经过训练以生成重建样本。

有关VQ-VAE的概述,请参考原始论文和这个视频讲解。 如果您需要回顾VAE,可以参考这本书的章节。 VQ-VAE是DALL-E背后的主要方法之一,代码本的思想也被用于VQ-GANs

本示例使用了来自DeepMind的官方VQ-VAE教程的实现细节。

需求

要运行此示例,您需要TensorFlow 2.5或更高版本,以及可以使用以下命令安装的TensorFlow概率库。

!pip install -q tensorflow-probability

导入

import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow as tf

VectorQuantizer

首先,我们实现一个自定义层用于向量量化器,该层在编码器和解码器之间。假设编码器的输出形状为(batch_size, height, width, num_filters)。向量量化器将首先扁平化此输出,仅保持num_filters维度不变。因此,形状将变为(batch_size * height * width, num_filters)。这样做的原因是将滤波器的总数量视为潜在嵌入的大小。

然后初始化一个嵌入表以学习代码本。我们测量扁平化编码器输出和代码本的代码字之间的L2归一化距离。我们选择产生最小距离的代码,并应用独热编码以实现量化。通过这种方式,产生与相应编码器输出的最小距离的代码被映射为1,而其余代码被映射为0。

由于量化过程是不可微的,我们在解码器和编码器之间应用一个直通估计器,以便将解码器的梯度直接传播到编码器。由于编码器和解码器共享相同的通道空间,解码器的梯度对于编码器仍然是有意义的。

class VectorQuantizer(layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # `beta` 参数最好保持在 [0.25, 2] 之间。
        self.beta = beta

        # 初始化我们将进行量化的嵌入。
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # 计算输入的输入形状,然后在保持 `embedding_dim` 不变的情况下将输入展平。
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # 量化。
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # 将量化值重新形状为原始输入形状
        quantized = tf.reshape(quantized, input_shape)

        # 计算向量量化损失并将其添加到层中。你可以在这里了解更多关于将损失添加到不同层的信息:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. 请查看
        # 原论文以了解损失函数的公式。
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # 直通估计器。
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # 计算输入与代码之间的 L2 归一化距离。
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings ** 2, axis=0)
            - 2 * similarity
        )

        # 推导最小距离的索引。
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices

关于直通估计的说明:

这行代码执行了直通估计部分:quantized = x + tf.stop_gradient(quantized - x)。在反向传播过程中,(quantized - x)将不会包含在计算图中,并获得的quantized的梯度将被复制给inputs。感谢这个视频帮助我理解这项技术。


编码器和解码器

现在介绍VQ-VAE的编码器和解码器。我们将它们保持较小,以便它们的容量适合MNIST数据集。编码器和解码器的实现来自 这个例子

请注意,_除了ReLU以外的激活函数_可能不适用于量化架构中的编码器和解码器层:例如,泄漏ReLU激活的层证明了训练困难,导致模型难以恢复的间歇性损失峰值。

def get_encoder(latent_dim=16):
    encoder_inputs = keras.Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(
        encoder_inputs
    )
    x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
    encoder_outputs = layers.Conv2D(latent_dim, 1, padding="same")(x)
    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")


def get_decoder(latent_dim=16):
    latent_inputs = keras.Input(shape=get_encoder(latent_dim).output.shape[1:])
    x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
        latent_inputs
    )
    x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
    decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")

独立的VQ-VAE模型

def get_vqvae(latent_dim=16, num_embeddings=64):
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    encoder = get_encoder(latent_dim)
    decoder = get_decoder(latent_dim)
    inputs = keras.Input(shape=(28, 28, 1))
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return keras.Model(inputs, reconstructions, name="vq_vae")


get_vqvae().summary()
模型: "vq_vae"
_________________________________________________________________
层 (类型)                   输出形状                参数 #
=================================================================
input_4 (输入层)           [(None, 28, 28, 1)]       0         
_________________________________________________________________
encoder (功能模块)         (None, 7, 7, 16)          19856     
_________________________________________________________________
vector_quantizer (矢量量化器) (None, 7, 7, 16)          1024      
_________________________________________________________________
decoder (功能模块)         (None, 28, 28, 1)         28033     
=================================================================
总参数: 48,913
可训练参数: 48,913
不可训练参数: 0
_________________________________________________________________

请注意,编码器的输出通道应与矢量量化器的latent_dim匹配。


VQVAETrainer中封装训练循环

class VQVAETrainer(keras.models.Model):
    def __init__(self, train_variance, latent_dim=32, num_embeddings=128, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def train_step(self, x):
        with tf.GradientTape() as tape:
            # VQ-VAE的输出。
            reconstructions = self.vqvae(x)

            # 计算损失。
            reconstruction_loss = (
                tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # 反向传播。
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # 损失跟踪。
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # 记录结果。
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

加载和预处理 MNIST 数据集

(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()

x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
x_train_scaled = (x_train / 255.0) - 0.5
x_test_scaled = (x_test / 255.0) - 0.5

data_variance = np.var(x_train / 255.0)
从 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 下载数据
11493376/11490434 [==============================] - 0s 0us/step

训练 VQ-VAE 模型

vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=128)
vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
vqvae_trainer.fit(x_train_scaled, epochs=30, batch_size=128)
第 1/30 轮
469/469 [==============================] - 18s 6ms/step - loss: 2.2962 - reconstruction_loss: 0.3869 - vqvae_loss: 1.5950
第 2/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 2.2980 - reconstruction_loss: 0.1692 - vqvae_loss: 2.1108
第 3/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 1.1356 - reconstruction_loss: 0.1281 - vqvae_loss: 0.9997
第 4/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.6112 - reconstruction_loss: 0.1030 - vqvae_loss: 0.5031
第 5/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.4375 - reconstruction_loss: 0.0883 - vqvae_loss: 0.3464
第 6/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.3579 - reconstruction_loss: 0.0788 - vqvae_loss: 0.2775
第 7/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.3197 - reconstruction_loss: 0.0725 - vqvae_loss: 0.2457
第 8/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2960 - reconstruction_loss: 0.0673 - vqvae_loss: 0.2277
第 9/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2798 - reconstruction_loss: 0.0640 - vqvae_loss: 0.2152
第 10/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2681 - reconstruction_loss: 0.0612 - vqvae_loss: 0.2061
第 11/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2578 - reconstruction_loss: 0.0590 - vqvae_loss: 0.1986
第 12/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2551 - reconstruction_loss: 0.0574 - vqvae_loss: 0.1974
第 13/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2526 - reconstruction_loss: 0.0560 - vqvae_loss: 0.1961
第 14/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2485 - reconstruction_loss: 0.0546 - vqvae_loss: 0.1936
第 15/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2462 - reconstruction_loss: 0.0533 - vqvae_loss: 0.1926
第 16/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2445 - reconstruction_loss: 0.0523 - vqvae_loss: 0.1920
第 17/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2427 - reconstruction_loss: 0.0515 - vqvae_loss: 0.1911
第 18/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2405 - reconstruction_loss: 0.0505 - vqvae_loss: 0.1898
第 19/30 轮
469/469 [==============================] - 3s 6ms/step - loss: 0.2368 - reconstruction_loss: 0.0495 - vqvae_loss: 0.1871
第 20/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2310 - reconstruction_loss: 0.0486 - vqvae_loss: 0.1822
第 21/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2245 - reconstruction_loss: 0.0475 - vqvae_loss: 0.1769
第 22/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2205 - reconstruction_loss: 0.0469 - vqvae_loss: 0.1736
第 23/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2195 - reconstruction_loss: 0.0465 - vqvae_loss: 0.1730
第 24/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2187 - reconstruction_loss: 0.0461 - vqvae_loss: 0.1726
第 25/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2180 - reconstruction_loss: 0.0458 - vqvae_loss: 0.1721
第 26/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2163 - reconstruction_loss: 0.0454 - vqvae_loss: 0.1709
第 27/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2156 - reconstruction_loss: 0.0452 - vqvae_loss: 0.1704
第 28/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2146 - reconstruction_loss: 0.0449 - vqvae_loss: 0.1696
第 29/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2139 - reconstruction_loss: 0.0447 - vqvae_loss: 0.1692
第 30/30 轮
469/469 [==============================] - 3s 5ms/step - loss: 0.2127 - reconstruction_loss: 0.0444 - vqvae_loss: 0.1682

<tensorflow.python.keras.callbacks.History at 0x7f96402f4e50>

测试集上的重建结果

def show_subplot(original, reconstructed):
    plt.subplot(1, 2, 1)
    plt.imshow(original.squeeze() + 0.5)
    plt.title("原始图像")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reconstructed.squeeze() + 0.5)
    plt.title("重建图像")
    plt.axis("off")

    plt.show()


trained_vqvae_model = vqvae_trainer.vqvae
idx = np.random.choice(len(x_test_scaled), 10)
test_images = x_test_scaled[idx]
reconstructions_test = trained_vqvae_model.predict(test_images)

for test_image, reconstructed_image in zip(test_images, reconstructions_test):
    show_subplot(test_image, reconstructed_image)

png

png

png

png

png

png

png

png

png

png

这些结果看起来不错。鼓励您尝试不同的超参数(尤其是嵌入的数量和嵌入的维度),并观察它们如何影响结果。


可视化离散代码

encoder = vqvae_trainer.vqvae.get_layer("encoder")
quantizer = vqvae_trainer.vqvae.get_layer("vector_quantizer")

encoded_outputs = encoder.predict(test_images)
flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
codebook_indices = quantizer.get_code_indices(flat_enc_outputs)
codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])

for i in range(len(test_images)):
    plt.subplot(1, 2, 1)
    plt.imshow(test_images[i].squeeze() + 0.5)
    plt.title("原始图像")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(codebook_indices[i])
    plt.title("代码")
    plt.axis("off")
    plt.show()

png

png

png

png

png

png

png

png

png

png

上面的图显示,离散代码能够捕捉到数据集中的一些规律。现在,我们如何从这个代码簿中采样以创建新图像呢?由于这些代码是离散的,并且我们对它们施加了分类分布,因此在我们能够生成可以交给解码器的可能代码序列之前,无法用它们生成任何有意义的东西。 作者使用 PixelCNN 来训练这些代码,以便它们可以用作生成新示例的强先验。PixelCNN 在 Conditional Image Generation with PixelCNN Decoders 中由 van der Oord 等人提出。我们将借用来自 this example 的代码,该示例也是由 van der Oord 等人提供。我们借用了 this PixelCNN example中的实现。它是一个自回归生成模型,输出条件依赖于之前的输出。换句话说,PixelCNN 逐像素生成图像。然而,在此示例中,其任务是生成代码簿索引,而不是直接生成像素。训练好的 VQ-VAE 解码器用于将 PixelCNN 生成的索引映射回像素空间。


PixelCNN 超参数

num_residual_blocks = 2
num_pixelcnn_layers = 2
pixelcnn_input_shape = encoded_outputs.shape[1:-1]
print(f"PixelCNN的输入形状: {pixelcnn_input_shape}")
PixelCNN的输入形状: (7, 7)

此输入形状表示编码器所执行的分辨率缩减。使用“相同”填充时,这正好将每个步幅为2的卷积层的输出形状的“分辨率”减半。因此,经过这两个层后,我们最终得到一个在第2轴和第3轴上为7x7的编码器输出张量,第1轴为批量大小,最后一轴为代码本嵌入大小。由于自编码器中的量化层将这些7x7张量映射到代码本的索引,因此这些输出层轴的大小必须与PixelCNN作为输入形状匹配。此架构下PixelCNN的任务是生成可能的7x7代码本索引排列。

请注意,这个形状是针对较大尺寸的图像域进行优化的,同时还有代码本大小。由于PixelCNN是自回归的,它需要依次遍历每个代码本索引,才能生成代码本的新图像。每个步幅为2的(或者更确切地说是步幅(2, 2))卷积层将使图像生成时间减少四倍。然而,请注意,这部分可能存在下限:当重建图像所需的代码数量太少时,解码器缺乏足够的信息来表示图像的细节水平,因此不同的图像域可能需要更大的代码本大小和更大的图像分辨率。 输出质量将会受到影响。至少可以通过使用更大的代码本来在一定程度上进行修正。由于图像生成过程的自回归部分使用代码本索引,因此使用更大的代码本的性能损失要小得多,因为从更大的代码本查找更大尺寸的代码所需的查找时间与迭代更长的代码本索引序列相比要小得多,尽管代码本的大小确实会影响能够通过图像生成过程的批量大小。找到这个权衡的最佳选择可能需要对架构进行一些调整,并且可能会因数据集而异。


PixelCNN模型

大部分内容来自 this example

注意事项

感谢 Rein van 't Veer 通过复制编辑和小的代码清理改进了这个示例。

# 第一层是PixelCNN层。该层简单地
# 基于二维卷积层,但包含掩码。
class PixelConvLayer(layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # 构建conv2d层以初始化内核变量
        self.conv.build(input_shape)
        # 使用初始化的内核创建掩码
        kernel_shape = self.conv.kernel.get_shape()
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)


# 接下来,我们构建我们的残差块层。
# 这只是一个普通的残差块,但基于PixelConvLayer。
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        return keras.layers.add([inputs, x])


pixelcnn_inputs = keras.Input(shape=pixelcnn_input_shape, dtype=tf.int32)
ohe = tf.one_hot(pixelcnn_inputs, vqvae_trainer.num_embeddings)
x = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(ohe)

for _ in range(num_residual_blocks):
    x = ResidualBlock(filters=128)(x)

for _ in range(num_pixelcnn_layers):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

out = keras.layers.Conv2D(
    filters=vqvae_trainer.num_embeddings, kernel_size=1, strides=1, padding="valid"
)(x)

pixel_cnn = keras.Model(pixelcnn_inputs, out, name="pixel_cnn")
pixel_cnn.summary()
模型: "pixel_cnn"
_________________________________________________________________
层 (类型)                 输出形状              参数 #   
=================================================================
input_9 (输入层)         [(None, 7, 7)]            0         
_________________________________________________________________
tf.one_hot (TFOpLambda)      (None, 7, 7, 128)         0         
_________________________________________________________________
pixel_conv_layer (PixelConvL (None, 7, 7, 128)         802944    
_________________________________________________________________
residual_block (ResidualBloc (None, 7, 7, 128)         98624     
_________________________________________________________________
residual_block_1 (ResidualBl (None, 7, 7, 128)         98624     
_________________________________________________________________
pixel_conv_layer_3 (PixelCon (None, 7, 7, 128)         16512     
_________________________________________________________________
pixel_conv_layer_4 (PixelCon (None, 7, 7, 128)         16512     
_________________________________________________________________
conv2d_21 (Conv2D)           (None, 7, 7, 128)         16512     
=================================================================
总参数: 1,049,728
可训练参数: 1,049,728
不可训练参数: 0
_________________________________________________________________

准备数据以训练PixelCNN

我们将训练PixelCNN以学习离散编码的分类分布。 首先,我们将使用刚刚训练过的编码器和向量量化器生成代码索引。我们的训练目标是最小化这些索引与PixelCNN输出之间的交叉熵损失。在这里,类别的数量等于我们代码本中存在的嵌入数量(在我们的例子中为128)。PixelCNN模型被训练以学习一种分布(而不是最小化L1/L2损失),这就是它获取生成能力的来源。

# 生成代码本索引。
encoded_outputs = encoder.predict(x_train_scaled)
flat_enc_outputs = encoded_outputs.reshape(-1, encoded_outputs.shape[-1])
codebook_indices = quantizer.get_code_indices(flat_enc_outputs)

codebook_indices = codebook_indices.numpy().reshape(encoded_outputs.shape[:-1])
print(f"Shape of the training data for PixelCNN: {codebook_indices.shape}")
Shape of the training data for PixelCNN: (60000, 7, 7)

PixelCNN 训练

pixel_cnn.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
pixel_cnn.fit(
    x=codebook_indices,
    y=codebook_indices,
    batch_size=128,
    epochs=30,
    validation_split=0.1,
)
Epoch 1/30
422/422 [==============================] - 4s 8ms/step - loss: 1.8550 - accuracy: 0.5959 - val_loss: 1.3127 - val_accuracy: 0.6268
Epoch 2/30
422/422 [==============================] - 3s 7ms/step - loss: 1.2207 - accuracy: 0.6402 - val_loss: 1.1722 - val_accuracy: 0.6482
Epoch 3/30
422/422 [==============================] - 3s 7ms/step - loss: 1.1412 - accuracy: 0.6536 - val_loss: 1.1313 - val_accuracy: 0.6552
Epoch 4/30
422/422 [==============================] - 3s 7ms/step - loss: 1.1060 - accuracy: 0.6601 - val_loss: 1.1058 - val_accuracy: 0.6596
Epoch 5/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0828 - accuracy: 0.6646 - val_loss: 1.1020 - val_accuracy: 0.6603
Epoch 6/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0649 - accuracy: 0.6682 - val_loss: 1.0809 - val_accuracy: 0.6638
Epoch 7/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0515 - accuracy: 0.6710 - val_loss: 1.0712 - val_accuracy: 0.6659
Epoch 8/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0406 - accuracy: 0.6733 - val_loss: 1.0647 - val_accuracy: 0.6671
Epoch 9/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0312 - accuracy: 0.6752 - val_loss: 1.0633 - val_accuracy: 0.6674
Epoch 10/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0235 - accuracy: 0.6771 - val_loss: 1.0554 - val_accuracy: 0.6695
Epoch 11/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0162 - accuracy: 0.6788 - val_loss: 1.0518 - val_accuracy: 0.6694
Epoch 12/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0105 - accuracy: 0.6799 - val_loss: 1.0541 - val_accuracy: 0.6693
Epoch 13/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0050 - accuracy: 0.6811 - val_loss: 1.0481 - val_accuracy: 0.6705
Epoch 14/30
422/422 [==============================] - 3s 7ms/step - loss: 1.0011 - accuracy: 0.6820 - val_loss: 1.0462 - val_accuracy: 0.6709
Epoch 15/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9964 - accuracy: 0.6831 - val_loss: 1.0459 - val_accuracy: 0.6709
Epoch 16/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9922 - accuracy: 0.6840 - val_loss: 1.0444 - val_accuracy: 0.6704
Epoch 17/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9884 - accuracy: 0.6848 - val_loss: 1.0405 - val_accuracy: 0.6725
Epoch 18/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9846 - accuracy: 0.6859 - val_loss: 1.0400 - val_accuracy: 0.6722
Epoch 19/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9822 - accuracy: 0.6864 - val_loss: 1.0394 - val_accuracy: 0.6728
Epoch 20/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9787 - accuracy: 0.6872 - val_loss: 1.0393 - val_accuracy: 0.6717
Epoch 21/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9761 - accuracy: 0.6878 - val_loss: 1.0398 - val_accuracy: 0.6725
Epoch 22/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9733 - accuracy: 0.6884 - val_loss: 1.0376 - val_accuracy: 0.6726
Epoch 23/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9708 - accuracy: 0.6890 - val_loss: 1.0352 - val_accuracy: 0.6732
Epoch 24/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9685 - accuracy: 0.6894 - val_loss: 1.0369 - val_accuracy: 0.6723
Epoch 25/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9660 - accuracy: 0.6901 - val_loss: 1.0384 - val_accuracy: 0.6733
Epoch 26/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9638 - accuracy: 0.6908 - val_loss: 1.0355 - val_accuracy: 0.6728
Epoch 27/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9619 - accuracy: 0.6912 - val_loss: 1.0325 - val_accuracy: 0.6739
Epoch 28/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9594 - accuracy: 0.6917 - val_loss: 1.0334 - val_accuracy: 0.6736
Epoch 29/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9582 - accuracy: 0.6920 - val_loss: 1.0366 - val_accuracy: 0.6733
Epoch 30/30
422/422 [==============================] - 3s 7ms/step - loss: 0.9561 - accuracy: 0.6926 - val_loss: 1.0336 - val_accuracy: 0.6728

<tensorflow.python.keras.callbacks.History at 0x7f95838ef750>

我们可以通过更多的训练和超参数调整来改善这些分数。


代码字典采样

现在我们的 PixelCNN 已经训练完成,我们可以从其输出中采样不同的代码,并将它们传递给我们的解码器以生成新颖的图像。

# 创建一个迷你采样模型。
inputs = layers.Input(shape=pixel_cnn.input_shape[1:])
outputs = pixel_cnn(inputs, training=False)
categorical_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
outputs = categorical_layer(outputs)
sampler = keras.Model(inputs, outputs)

我们现在构建一个先验来生成图像。在这里,我们将生成 10 张图像。

# 创建一个空的先验数组。
batch = 10
priors = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols = priors.shape

# 逐行逐列迭代先验,因为生成必须逐像素顺序进行。
for row in range(rows):
    for col in range(cols):
        # 喂入整个数组并获取下一个像素的像素值概率。
        probs = sampler.predict(priors)
        # 使用概率选择像素值并将这些值附加到先验中。
        priors[:, row, col] = probs[:, row, col]

print(f"Prior shape: {priors.shape}")
先验形状: (10, 7, 7)

我们现在可以使用我们的解码器生成图像。

# 执行嵌入查找。
pretrained_embeddings = quantizer.embeddings
priors_ohe = tf.one_hot(priors.astype("int32"), vqvae_trainer.num_embeddings).numpy()
quantized = tf.matmul(
    priors_ohe.astype("float32"), pretrained_embeddings, transpose_b=True
)
quantized = tf.reshape(quantized, (-1, *(encoded_outputs.shape[1:])))

# 生成新颖图像。
decoder = vqvae_trainer.vqvae.get_layer("decoder")
generated_samples = decoder.predict(quantized)

for i in range(batch):
    plt.subplot(1, 2, 1)
    plt.imshow(priors[i])
    plt.title("代码")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(generated_samples[i].squeeze() + 0.5)
    plt.title("生成样本")
    plt.axis("off")
    plt.show()

png

png

png

png

png

png

png

png

png

png

我们可以通过调整 PixelCNN 来提高这些生成样本的质量。


附加说明

  • 在 VQ-VAE 论文最初发布后,作者开发了一种指数滑动平均方案,以更新量化器内部的嵌入。如果您感兴趣,可以查看 这个代码片段
  • 为了进一步提高生成样本的质量,提出了 VQ-VAE-2,它采用级联方法学习代码字典并生成图像。