PixelCNN

作者: ADMoreau
创建日期: 2020/05/17
最后修改: 2020/05/23
描述: 使用 Keras 实现的 PixelCNN。

在 Colab 中查看 GitHub 源码


介绍

PixelCNN 是一个生成模型,由 van den Oord 等人在 2016 年提出。 (参考文献: Conditional Image Generation with PixelCNN Decoders) 它旨在通过一个输入向量迭代生成图像(或其他数据类型),其中前一个元素的概率分布决定后续元素的概率分布。在下面的示例中,图像以这种方式逐像素生成,通过一个仅查看先前生成的像素(起点为左上角)的掩码卷积核来生成后续像素。 在推理期间,网络的输出被用作概率分布,从中采样新像素值以生成新图像(在这里,使用 MNIST,像素值是黑色或白色)。

import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm

获取数据

# 模型 / 数据参数
num_classes = 10
input_shape = (28, 28, 1)
n_residual_blocks = 5
# 数据,分为训练集和测试集
(x, _), (y, _) = keras.datasets.mnist.load_data()
# 将所有图像连接在一起
data = np.concatenate((x, y), axis=0)
# 将所有小于最大值 256 的 33% 的像素值四舍五入为 0
# 任何高于此值的都四舍五入为 1,以使所有值均为
# 0 或 1
data = np.where(data < (0.33 * 256), 0, 1)
data = data.astype(np.float32)

为模型创建所需层的两个类

# 第一个层是 PixelCNN 层。这个层简单地
# 建立在 2D 卷积层之上,但包含掩码。
class PixelConvLayer(layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # 建立 conv2d 层以初始化内核变量
        self.conv.build(input_shape)
        # 使用初始化的内核创建掩码
        kernel_shape = ops.shape(self.conv.kernel)
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)


# 接下来,我们构建我们的残差块层。
# 这只是一个普通的残差块,但基于 PixelConvLayer。
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        return keras.layers.add([inputs, x])

基于原始论文构建模型

inputs = keras.Input(shape=input_shape, batch_size=128)
x = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(inputs)

for _ in range(n_residual_blocks):
    x = ResidualBlock(filters=128)(x)

for _ in range(2):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

out = keras.layers.Conv2D(
    filters=1, kernel_size=1, strides=1, activation="sigmoid", padding="valid"
)(x)

pixel_cnn = keras.Model(inputs, out)
adam = keras.optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy")

pixel_cnn.summary()
pixel_cnn.fit(
    x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2
)
模型: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ 层 (类型)                          输出形状                   参数 # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 输入层 (InputLayer)        │ (128, 28, 28, 1)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 像素卷积层                     │ (128, 28, 28, 128)        │      6,400 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 残差块 (ResidualBlock)  │ (128, 28, 28, 128)        │     98,624 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 残差块_1                      │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 残差块_2                      │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 残差块_3                      │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ 残差块_4                      │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_6              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_7              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_18 (Conv2D)              │ (128, 28, 28, 1)          │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 总参数: 532,673 (2.03 MB)
 可训练参数: 532,673 (2.03 MB)
 非可训练参数: 0 (0.00 B)
Epoch 1/50
493/493 - 26s - 53ms/step - loss: 0.1137 - val_loss: 0.0933
Epoch 2/50
493/493 - 14s - 29ms/step - loss: 0.0915 - val_loss: 0.0901
Epoch 3/50
493/493 - 14s - 29ms/step - loss: 0.0893 - val_loss: 0.0888
Epoch 4/50
493/493 - 14s - 29ms/step - loss: 0.0882 - val_loss: 0.0880
Epoch 5/50
493/493 - 14s - 29ms/step - loss: 0.0874 - val_loss: 0.0870
Epoch 6/50
493/493 - 14s - 29ms/step - loss: 0.0867 - val_loss: 0.0867
Epoch 7/50
493/493 - 14s - 29ms/step - loss: 0.0863 - val_loss: 0.0867
Epoch 8/50
493/493 - 14s - 29ms/step - loss: 0.0859 - val_loss: 0.0860
Epoch 9/50
493/493 - 14s - 29ms/step - loss: 0.0855 - val_loss: 0.0856
Epoch 10/50
493/493 - 14s - 29ms/step - loss: 0.0853 - val_loss: 0.0861
Epoch 11/50
493/493 - 14s - 29ms/step - loss: 0.0850 - val_loss: 0.0860
Epoch 12/50
493/493 - 14s - 29ms/step - loss: 0.0847 - val_loss: 0.0873
Epoch 13/50
493/493 - 14s - 29ms/step - loss: 0.0846 - val_loss: 0.0852
Epoch 14/50
493/493 - 14s - 29ms/step - loss: 0.0844 - val_loss: 0.0846
Epoch 15/50
493/493 - 14s - 29ms/step - loss: 0.0842 - val_loss: 0.0848
Epoch 16/50
493/493 - 14s - 29ms/step - loss: 0.0840 - val_loss: 0.0843
Epoch 17/50
493/493 - 14s - 29ms/step - loss: 0.0838 - val_loss: 0.0847
Epoch 18/50
493/493 - 14s - 29ms/step - loss: 0.0837 - val_loss: 0.0841
Epoch 19/50
493/493 - 14s - 29ms/step - loss: 0.0835 - val_loss: 0.0842
Epoch 20/50
493/493 - 14s - 29ms/step - loss: 0.0834 - val_loss: 0.0844
Epoch 21/50
493/493 - 14s - 29ms/step - loss: 0.0834 - val_loss: 0.0843
Epoch 22/50
493/493 - 14s - 29ms/step - loss: 0.0832 - val_loss: 0.0838
Epoch 23/50
493/493 - 14s - 29ms/step - loss: 0.0831 - val_loss: 0.0840
Epoch 24/50
493/493 - 14s - 29ms/step - loss: 0.0830 - val_loss: 0.0841
Epoch 25/50
493/493 - 14s - 29ms/step - loss: 0.0829 - val_loss: 0.0837
Epoch 26/50
493/493 - 14s - 29ms/step - loss: 0.0828 - val_loss: 0.0837
Epoch 27/50
493/493 - 14s - 29ms/step - loss: 0.0827 - val_loss: 0.0836
Epoch 28/50
493/493 - 14s - 29ms/step - loss: 0.0827 - val_loss: 0.0836
Epoch 29/50
493/493 - 14s - 29ms/step - loss: 0.0825 - val_loss: 0.0838
Epoch 30/50
493/493 - 14s - 29ms/step - loss: 0.0825 - val_loss: 0.0834
Epoch 31/50
493/493 - 14s - 29ms/step - loss: 0.0824 - val_loss: 0.0832
Epoch 32/50
493/493 - 14s - 29ms/step - loss: 0.0823 - val_loss: 0.0833
Epoch 33/50
493/493 - 14s - 29ms/step - loss: 0.0822 - val_loss: 0.0836
Epoch 34/50
493/493 - 14s - 29ms/step - loss: 0.0822 - val_loss: 0.0832
Epoch 35/50
493/493 - 14s - 29ms/step - loss: 0.0821 - val_loss: 0.0832
Epoch 36/50
493/493 - 14s - 29ms/step - loss: 0.0820 - val_loss: 0.0835
Epoch 37/50
493/493 - 14s - 29ms/step - loss: 0.0820 - val_loss: 0.0834
Epoch 38/50
493/493 - 14s - 29ms/step - loss: 0.0819 - val_loss: 0.0833
Epoch 39/50
493/493 - 14s - 29ms/step - loss: 0.0818 - val_loss: 0.0832
Epoch 40/50
493/493 - 14s - 29ms/step - loss: 0.0818 - val_loss: 0.0834
Epoch 41/50
493/493 - 14s - 29ms/step - loss: 0.0817 - val_loss: 0.0832
Epoch 42/50
493/493 - 14s - 29ms/step - loss: 0.0816 - val_loss: 0.0834
Epoch 43/50
493/493 - 14s - 29ms/step - loss: 0.0816 - val_loss: 0.0839
Epoch 44/50
493/493 - 14s - 29ms/step - loss: 0.0815 - val_loss: 0.0831
Epoch 45/50
493/493 - 14s - 29ms/step - loss: 0.0815 - val_loss: 0.0832
Epoch 46/50
493/493 - 14s - 29ms/step - loss: 0.0814 - val_loss: 0.0835
Epoch 47/50
493/493 - 14s - 29ms/step - loss: 0.0814 - val_loss: 0.0830
Epoch 48/50
493/493 - 14s - 29ms/step - loss: 0.0813 - val_loss: 0.0832
Epoch 49/50
493/493 - 14s - 29ms/step - loss: 0.0812 - val_loss: 0.0833
Epoch 50/50
493/493 - 14s - 29ms/step - loss: 0.0812 - val_loss: 0.0831

<keras.src.callbacks.history.History at 0x7f45e6d78760>

演示

PixelCNN不能一次生成完整图像。相反,它必须按顺序生成每个像素,将最后生成的像素附加到当前图像中,并将图像反馈给模型以重复该过程。

from IPython.display import Image, display

# 创建一个空的像素数组。
batch = 4
pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols, channels = pixels.shape

# 遍历像素,因为生成必须按像素逐个顺序进行。
for row in tqdm(range(rows)):
    for col in range(cols):
        for channel in range(channels):
            # 输入整个数组并获取下一个像素的像素值概率。
            probs = pixel_cnn.predict(pixels)[:, row, col, channel]
            # 使用概率选择像素值并将其附加到图像框架中。
            pixels[:, row, col, channel] = ops.ceil(
                probs - keras.random.uniform(probs.shape)
            )


def deprocess_image(x):
    # 将单通道黑白图像堆叠为rgb值。
    x = np.stack((x, x, x), 2)
    # 取消预处理
    x *= 255.0
    # 转换为uint8并裁剪到有效范围[0, 255]
    x = np.clip(x, 0, 255).astype("uint8")
    return x


# 遍历生成的图像并使用matplotlib绘制它们。
for i, pic in enumerate(pixels):
    keras.utils.save_img(
        "generated_image_{}.png".format(i), deprocess_image(np.squeeze(pic, -1))
    )

display(Image("generated_image_0.png"))
display(Image("generated_image_1.png"))
display(Image("generated_image_2.png"))
display(Image("generated_image_3.png"))
100%|███████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.51it/s]

png

png

png

png