Author: A_K_Nain
Date created: 2020/08/12
Last modified: 2020/08/12
Description: Implementation of CycleGAN.
CycleGAN is a model that aims to solve the image-to-image translation problem. The goal of the image-to-image translation problem is to learn the mapping between an input image and an output image using a training set of aligned image pairs. However, obtaining paired examples isn't always feasible. CycleGAN tries to learn this mapping without requiring paired input-output images, using cycle-consistent adversarial networks.
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
autotune = tf.data.AUTOTUNE
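Before preparing the data, it can help to see the cycle consistency idea from the introduction in code. Below is a minimal, illustrative sketch: G and F are identity placeholders standing in for the two generators, not the models built later in this example.

# Identity placeholders for the two mappings G: X -> Y and F: Y -> X.
G = lambda t: t
F = lambda t: t

mae = keras.losses.MeanAbsoluteError()
x = tf.random.uniform((1, 64, 64, 3), -1.0, 1.0)  # an image from domain X
y = tf.random.uniform((1, 64, 64, 3), -1.0, 1.0)  # an image from domain Y

# Cycle consistency: x -> G(x) -> F(G(x)) should land back on x (and likewise
# for y). This constraint is what makes training without paired images possible.
cycle_loss = mae(x, F(G(x))) + mae(y, G(F(y)))
print(float(cycle_loss))  # 0.0 for the identity placeholders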
In this example, we will use the horse to zebra dataset.
# Load the horse-zebra dataset using tensorflow-datasets.
dataset, _ = tfds.load("cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]
# Define the standard image size.
orig_img_size = (286, 286)
# Size of the random crops to be used during training.
input_img_size = (256, 256, 3)
# Weights initializer for the layers.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
# Gamma initializer for instance normalization.
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
buffer_size = 256
batch_size = 1
def normalize_img(img):
    img = tf.cast(img, dtype=tf.float32)
    # Map values to the range [-1, 1]
    return (img / 127.5) - 1.0


def preprocess_train_image(img, label):
    # Random flip
    img = tf.image.random_flip_left_right(img)
    # Resize to the original size first
    img = tf.image.resize(img, [*orig_img_size])
    # Random crop to 256 x 256
    img = tf.image.random_crop(img, size=[*input_img_size])
    # Normalize the pixel values to the range [-1, 1]
    img = normalize_img(img)
    return img


def preprocess_test_image(img, label):
    # Only resizing and normalization for the test images.
    img = tf.image.resize(img, [input_img_size[0], input_img_size[1]])
    img = normalize_img(img)
    return img
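As a quick sanity check of the functions above (illustrative, using a dummy all-black image), the training preprocessing should produce a 256x256x3 float image with values in [-1, 1]:

dummy = tf.zeros((300, 300, 3), dtype=tf.uint8)
out = preprocess_train_image(dummy, label=None)
print(out.shape, float(tf.reduce_min(out)), float(tf.reduce_max(out)))
# (256, 256, 3) -1.0 -1.0  -- an all-black image maps to -1 everywhere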
Create Dataset objects
# Apply the preprocessing operations to the training data
train_horses = (
    train_horses.map(preprocess_train_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
train_zebras = (
    train_zebras.map(preprocess_train_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)

# Apply the preprocessing operations to the test data
test_horses = (
    test_horses.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
test_zebras = (
    test_zebras.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
)
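Before visualizing anything, it can be worth confirming what the pipelines yield; each element should be a float32 batch of shape (batch_size, 256, 256, 3):

print(train_horses.element_spec)
# TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None)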
_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, samples in enumerate(zip(train_horses.take(4), train_zebras.take(4))):
    horse = (((samples[0][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
    ax[i, 0].imshow(horse)
    ax[i, 1].imshow(zebra)
plt.show()
class ReflectionPadding2D(layers.Layer):
    """Implements reflection padding as a layer.

    Args:
        padding(tuple): Amount of padding for the spatial dimensions.

    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")
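A quick usage sketch of this layer: padding of (3, 3) mirrors three border pixels on each side, so a 256x256 input grows to 262x262 without introducing zero-valued edges:

pad = ReflectionPadding2D(padding=(3, 3))
print(pad(tf.zeros((1, 256, 256, 3))).shape)  # (1, 262, 262, 3)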
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    dim = x.shape[-1]
    input_tensor = x

    x = ReflectionPadding2D()(input_tensor)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    x = ReflectionPadding2D()(x)
    x = layers.Conv2D(
        dim,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.add([input_tensor, x])
    return x
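Because the "valid" convolutions are wrapped in reflection padding, the block preserves both spatial size and channel count, which is what allows the final add() with the input. A shape check (illustrative):

inp = keras.Input(shape=(64, 64, 256))
out = residual_block(inp, activation=layers.Activation("relu"))
print(out.shape)  # (None, 64, 64, 256)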
def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    x = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )(x)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    if activation:
        x = activation(x)
    return x
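These two blocks are exact mirrors in terms of shape: downsample halves the spatial dimensions (stride 2, "same" padding) and upsample doubles them back via a transposed convolution. A shape check (illustrative):

inp = keras.Input(shape=(64, 64, 128))
down = downsample(inp, filters=256, activation=layers.Activation("relu"))
up = upsample(down, filters=128, activation=layers.Activation("relu"))
print(down.shape, up.shape)  # (None, 32, 32, 256) (None, 64, 64, 128)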
The generator consists of downsampling blocks, nine residual blocks, and upsampling blocks. The structure of the generator is the following:

c7s1-64 ==> Conv block with `relu` activation, filter size of 7
d128 ====|
         |-> 2 downsampling blocks
d256 ====|
R256 ====|
R256     |
R256     |
R256     |
R256     |-> 9 residual blocks
R256     |
R256     |
R256     |
R256 ====|
u128 ====|
         |-> 2 upsampling blocks
u64  ====|
c7s1-3 => Last conv block with `tanh` activation, filter size of 7.
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
The discriminators implement the following architecture:
C64->C128->C256->C512
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    img_input = layers.Input(shape=input_img_size, name=name + "_img_input")
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    for num_downsample_block in range(3):
        num_filters *= 2
        if num_downsample_block < 2:
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(2, 2),
            )
        else:
            x = downsample(
                x,
                filters=num_filters,
                activation=layers.LeakyReLU(0.2),
                kernel_size=(4, 4),
                strides=(1, 1),
            )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model
# Get the generators
gen_G = get_resnet_generator(name="generator_G")
gen_F = get_resnet_generator(name="generator_F")

# Get the discriminators
disc_X = get_discriminator(name="discriminator_X")
disc_Y = get_discriminator(name="discriminator_Y")
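Note that the discriminators are PatchGAN-style classifiers: rather than a single real/fake score per image, they output a grid of scores, each judging one patch of the input. A quick check (for 256x256 inputs, the grid should be 32x32):

print(disc_X.output_shape)  # (None, 32, 32, 1)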
We will override the train_step() method of the Model class for training via fit().
class CycleGan(keras.Model):
    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super().__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        super().compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()
    def train_step(self, batch_data):
        # x is Horse and y is zebra
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different kinds of losses for the
        # generators and discriminators. We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if we
        #    can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators' total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators' loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake zebra
            fake_y = self.gen_G(real_x, training=True)
            # Zebra to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (horse to fake zebra to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (zebra to fake horse to fake zebra): y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate and save images after each epoch"""

    def __init__(self, num_img=4):
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        _, ax = plt.subplots(4, 2, figsize=(12, 12))
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.utils.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()
# Loss function for evaluating adversarial loss
adv_loss_fn = keras.losses.MeanSquaredError()


# Define the loss function for the generators
def generator_loss_fn(fake):
    fake_loss = adv_loss_fn(tf.ones_like(fake), fake)
    return fake_loss


# Define the loss function for the discriminators
def discriminator_loss_fn(real, fake):
    real_loss = adv_loss_fn(tf.ones_like(real), real)
    fake_loss = adv_loss_fn(tf.zeros_like(fake), fake)
    return (real_loss + fake_loss) * 0.5
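As a sanity check of the least-squares GAN losses above (illustrative): a discriminator that scores real patches as 1 and fake patches as 0 incurs zero loss, and so does a generator whose fakes are scored as 1:

print(float(discriminator_loss_fn(tf.ones((1, 32, 32, 1)), tf.zeros((1, 32, 32, 1)))))  # 0.0
print(float(generator_loss_fn(tf.ones((1, 32, 32, 1)))))  # 0.0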
# Create the CycleGAN model
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Compile the model
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.legacy.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)
# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, save_weights_only=True
)
# Here we will train the model for just one epoch as each epoch takes around
# 7 minutes on a single P100 backed machine.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_horses, train_zebras)),
    epochs=1,
    callbacks=[plotter, model_checkpoint_callback],
)
1067/1067 [==============================] - ETA: 0s - G_loss: 4.4794 - F_loss: 4.1048 - D_X_loss: 0.1584 - D_Y_loss: 0.1233
1067/1067 [==============================] - 390s 366ms/step - G_loss: 4.4783 - F_loss: 4.1035 - D_X_loss: 0.1584 - D_Y_loss: 0.1232
<tensorflow.python.keras.callbacks.History at 0x7f4184326e90>
Test the performance of the model.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.
# This model was trained for 90 epochs. We will be loading those weights here.
# Once the weights are loaded, we will take a few samples from the test data
# and check the model's performance.

!curl -LO https://github.com/AakashKumarNain/CycleGAN_TF2/releases/download/v1.0/saved_checkpoints.zip
!unzip -qq saved_checkpoints.zip

# Load the checkpoints
weight_file = "./saved_checkpoints/cyclegan_checkpoints.090"
cycle_gan_model.load_weights(weight_file).expect_partial()
print("Weights loaded successfully")
_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, img in enumerate(test_horses.take(4)):
    prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Translated image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")

    prediction = keras.utils.array_to_img(prediction)
    prediction.save("predicted_img_{i}.png".format(i=i))
plt.tight_layout()
plt.show()
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   634  100   634    0     0   2874      0 --:--:-- --:--:-- --:--:--  2881
100  273M  100  273M    0     0  1736k      0  0:02:41  0:02:41 --:--:-- 2049k
Weights loaded successfully