Author: ADMoreau
Date created: 2020/05/17
Last modified: 2020/05/23
Description: PixelCNN implemented in Keras.
PixelCNN is a generative model proposed in 2016 by van den Oord et al. (reference: Conditional Image Generation with PixelCNN Decoders). It is designed to generate an image (or other data type) iteratively from an input vector, where the probability distribution over each element is conditioned on the elements generated before it. In the example below, images are generated in exactly this way, pixel by pixel, via a masked convolution kernel that only looks at pixels generated earlier (starting from the top-left corner) when producing the next pixel. During inference, the output of the network is used as a probability distribution from which new pixel values are sampled to generate a new image (here, with MNIST, the pixel values are either black or white).
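To make the masking scheme concrete, here is a minimal standalone sketch (not part of the original example) that builds the type "A" and type "B" masks for a 3x3 kernel with plain NumPy, using the same construction as the PixelConvLayer defined below: rows above the centre are visible, pixels to the left of the centre in the same row are visible, and the centre pixel itself is included only in the type "B" mask.
import numpy as np

# Illustrative sketch: the spatial part of the PixelCNN masks for a 3x3 kernel
# (channel dimensions omitted). Type "A" hides the centre pixel, type "B" keeps it.
kernel_size = 3
mask_a = np.zeros((kernel_size, kernel_size))
mask_a[: kernel_size // 2, :] = 1.0  # all rows above the centre row
mask_a[kernel_size // 2, : kernel_size // 2] = 1.0  # left of the centre pixel
mask_b = mask_a.copy()
mask_b[kernel_size // 2, kernel_size // 2] = 1.0  # type "B" also sees the centre pixel
print(mask_a)  # [[1. 1. 1.], [1. 0. 0.], [0. 0. 0.]]
print(mask_b)  # [[1. 1. 1.], [1. 1. 0.], [0. 0. 0.]]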
import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
n_residual_blocks = 5
# The data, split between train and test sets
(x, _), (y, _) = keras.datasets.mnist.load_data()
# Concatenate all of the images together
data = np.concatenate((x, y), axis=0)
# Round all pixel values less than 33% of the max 256 value to 0
# and round anything above that threshold up to 1, so that all values are either
# 0 or 1
data = np.where(data < (0.33 * 256), 0, 1)
data = data.astype(np.float32)
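As a quick sanity check (illustrative only, not part of the original script), you can verify that the thresholding left only the two values the binary model expects and that the 60,000 training and 10,000 test images were concatenated:
# Illustrative check: every pixel is now either 0.0 or 1.0.
assert set(np.unique(data)) == {0.0, 1.0}
print(data.shape)  # (70000, 28, 28)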
# The first layer is the PixelCNN layer. This layer simply
# builds on the 2D convolutional layer, but includes masking.
class PixelConvLayer(layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super().__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # Build the conv2d layer to initialize the kernel variables
        self.conv.build(input_shape)
        # Use the initialized kernel to create the mask
        kernel_shape = ops.shape(self.conv.kernel)
        self.mask = np.zeros(shape=kernel_shape)
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)
# Next, we build our residual block layer.
# This is just a normal residual block, but based on the PixelConvLayer.
class ResidualBlock(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        return keras.layers.add([inputs, x])
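A quick way to see how these pieces fit together (an illustrative check that assumes the two classes above are already defined) is to pass a dummy batch through a single ResidualBlock: because of the 1x1 convolutions and padding="same", the output shape matches the input shape, which is what allows the skip connection to be added.
# Illustrative shape check: the residual block preserves the input shape.
dummy = ops.zeros((1, 28, 28, 128))
block = ResidualBlock(filters=128)
print(ops.shape(block(dummy)))  # (1, 28, 28, 128)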
inputs = keras.Input(shape=input_shape, batch_size=128)
x = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(inputs)

for _ in range(n_residual_blocks):
    x = ResidualBlock(filters=128)(x)

for _ in range(2):
    x = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)

out = keras.layers.Conv2D(
    filters=1, kernel_size=1, strides=1, activation="sigmoid", padding="valid"
)(x)
pixel_cnn = keras.Model(inputs, out)
adam = keras.optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy")
pixel_cnn.summary()
pixel_cnn.fit(
    x=data, y=data, batch_size=128, epochs=50, validation_split=0.1, verbose=2
)
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ input_layer (InputLayer)        │ (128, 28, 28, 1)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer                │ (128, 28, 28, 128)        │      6,400 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block (ResidualBlock)  │ (128, 28, 28, 128)        │     98,624 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_1                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_2                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_3                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ residual_block_4                │ (128, 28, 28, 128)        │     98,624 │
│ (ResidualBlock)                 │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_6              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ pixel_conv_layer_7              │ (128, 28, 28, 128)        │     16,512 │
│ (PixelConvLayer)                │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_18 (Conv2D)              │ (128, 28, 28, 1)          │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 532,673 (2.03 MB)
Trainable params: 532,673 (2.03 MB)
Non-trainable params: 0 (0.00 B)
Epoch 1/50
493/493 - 26s - 53ms/step - loss: 0.1137 - val_loss: 0.0933
Epoch 2/50
493/493 - 14s - 29ms/step - loss: 0.0915 - val_loss: 0.0901
Epoch 3/50
493/493 - 14s - 29ms/step - loss: 0.0893 - val_loss: 0.0888
Epoch 4/50
493/493 - 14s - 29ms/step - loss: 0.0882 - val_loss: 0.0880
Epoch 5/50
493/493 - 14s - 29ms/step - loss: 0.0874 - val_loss: 0.0870
Epoch 6/50
493/493 - 14s - 29ms/step - loss: 0.0867 - val_loss: 0.0867
Epoch 7/50
493/493 - 14s - 29ms/step - loss: 0.0863 - val_loss: 0.0867
Epoch 8/50
493/493 - 14s - 29ms/step - loss: 0.0859 - val_loss: 0.0860
Epoch 9/50
493/493 - 14s - 29ms/step - loss: 0.0855 - val_loss: 0.0856
Epoch 10/50
493/493 - 14s - 29ms/step - loss: 0.0853 - val_loss: 0.0861
Epoch 11/50
493/493 - 14s - 29ms/step - loss: 0.0850 - val_loss: 0.0860
Epoch 12/50
493/493 - 14s - 29ms/step - loss: 0.0847 - val_loss: 0.0873
Epoch 13/50
493/493 - 14s - 29ms/step - loss: 0.0846 - val_loss: 0.0852
Epoch 14/50
493/493 - 14s - 29ms/step - loss: 0.0844 - val_loss: 0.0846
Epoch 15/50
493/493 - 14s - 29ms/step - loss: 0.0842 - val_loss: 0.0848
Epoch 16/50
493/493 - 14s - 29ms/step - loss: 0.0840 - val_loss: 0.0843
Epoch 17/50
493/493 - 14s - 29ms/step - loss: 0.0838 - val_loss: 0.0847
Epoch 18/50
493/493 - 14s - 29ms/step - loss: 0.0837 - val_loss: 0.0841
Epoch 19/50
493/493 - 14s - 29ms/step - loss: 0.0835 - val_loss: 0.0842
Epoch 20/50
493/493 - 14s - 29ms/step - loss: 0.0834 - val_loss: 0.0844
Epoch 21/50
493/493 - 14s - 29ms/step - loss: 0.0834 - val_loss: 0.0843
Epoch 22/50
493/493 - 14s - 29ms/step - loss: 0.0832 - val_loss: 0.0838
Epoch 23/50
493/493 - 14s - 29ms/step - loss: 0.0831 - val_loss: 0.0840
Epoch 24/50
493/493 - 14s - 29ms/step - loss: 0.0830 - val_loss: 0.0841
Epoch 25/50
493/493 - 14s - 29ms/step - loss: 0.0829 - val_loss: 0.0837
Epoch 26/50
493/493 - 14s - 29ms/step - loss: 0.0828 - val_loss: 0.0837
Epoch 27/50
493/493 - 14s - 29ms/step - loss: 0.0827 - val_loss: 0.0836
Epoch 28/50
493/493 - 14s - 29ms/step - loss: 0.0827 - val_loss: 0.0836
Epoch 29/50
493/493 - 14s - 29ms/step - loss: 0.0825 - val_loss: 0.0838
Epoch 30/50
493/493 - 14s - 29ms/step - loss: 0.0825 - val_loss: 0.0834
Epoch 31/50
493/493 - 14s - 29ms/step - loss: 0.0824 - val_loss: 0.0832
Epoch 32/50
493/493 - 14s - 29ms/step - loss: 0.0823 - val_loss: 0.0833
Epoch 33/50
493/493 - 14s - 29ms/step - loss: 0.0822 - val_loss: 0.0836
Epoch 34/50
493/493 - 14s - 29ms/step - loss: 0.0822 - val_loss: 0.0832
Epoch 35/50
493/493 - 14s - 29ms/step - loss: 0.0821 - val_loss: 0.0832
Epoch 36/50
493/493 - 14s - 29ms/step - loss: 0.0820 - val_loss: 0.0835
Epoch 37/50
493/493 - 14s - 29ms/step - loss: 0.0820 - val_loss: 0.0834
Epoch 38/50
493/493 - 14s - 29ms/step - loss: 0.0819 - val_loss: 0.0833
Epoch 39/50
493/493 - 14s - 29ms/step - loss: 0.0818 - val_loss: 0.0832
Epoch 40/50
493/493 - 14s - 29ms/step - loss: 0.0818 - val_loss: 0.0834
Epoch 41/50
493/493 - 14s - 29ms/step - loss: 0.0817 - val_loss: 0.0832
Epoch 42/50
493/493 - 14s - 29ms/step - loss: 0.0816 - val_loss: 0.0834
Epoch 43/50
493/493 - 14s - 29ms/step - loss: 0.0816 - val_loss: 0.0839
Epoch 44/50
493/493 - 14s - 29ms/step - loss: 0.0815 - val_loss: 0.0831
Epoch 45/50
493/493 - 14s - 29ms/step - loss: 0.0815 - val_loss: 0.0832
Epoch 46/50
493/493 - 14s - 29ms/step - loss: 0.0814 - val_loss: 0.0835
Epoch 47/50
493/493 - 14s - 29ms/step - loss: 0.0814 - val_loss: 0.0830
Epoch 48/50
493/493 - 14s - 29ms/step - loss: 0.0813 - val_loss: 0.0832
Epoch 49/50
493/493 - 14s - 29ms/step - loss: 0.0812 - val_loss: 0.0833
Epoch 50/50
493/493 - 14s - 29ms/step - loss: 0.0812 - val_loss: 0.0831
<keras.src.callbacks.history.History at 0x7f45e6d78760>
The PixelCNN cannot generate the full image at once. Instead, it must generate each pixel in order, append the last generated pixel to the current image, and feed the image back into the model to repeat the process.
from IPython.display import Image, display
# Create an empty array of pixels.
batch = 4
pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols, channels = pixels.shape
# Iterate over the pixels because generation has to be done sequentially, pixel by pixel.
for row in tqdm(range(rows)):
    for col in range(cols):
        for channel in range(channels):
            # Feed the whole array in and retrieve the pixel value probabilities
            # for the next pixel.
            probs = pixel_cnn.predict(pixels)[:, row, col, channel]
            # Use the probabilities to pick pixel values and append the values
            # to the image frame.
            pixels[:, row, col, channel] = ops.ceil(
                probs - keras.random.uniform(probs.shape)
            )
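The assignment inside the loop is a compact Bernoulli sampler: for u drawn from Uniform(0, 1), ceil(p - u) equals 1 with probability p and 0 otherwise. Here is a minimal NumPy sketch of the same trick (illustrative only, not part of the original example):
# Illustrative: ceil(p - u) with u ~ Uniform(0, 1) is 1 with probability p.
p = 0.7
samples = np.ceil(p - np.random.uniform(size=10000))
print(samples.mean())  # approximately 0.7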
def deprocess_image(x):
    # Stack the single-channel black and white image into rgb values.
    x = np.stack((x, x, x), 2)
    # Undo the preprocessing
    x *= 255.0
    # Convert to uint8 and clip to the valid range [0, 255]
    x = np.clip(x, 0, 255).astype("uint8")
    return x
# Iterate over the generated images, save them to disk, and display them.
for i, pic in enumerate(pixels):
    keras.utils.save_img(
        "generated_image_{}.png".format(i), deprocess_image(np.squeeze(pic, -1))
    )
display(Image("generated_image_0.png"))
display(Image("generated_image_1.png"))
display(Image("generated_image_2.png"))
display(Image("generated_image_3.png"))
100%|███████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.51it/s]