作者: Sayak Paul, Chansung Park
创建日期: 2023/02/01
最后修改: 2023/02/05
描述: 实现 DreamBooth.
在这个示例中,我们实现了 DreamBooth,一种微调技术,可以通过仅使用 3 - 5 张图像来将新的视觉概念教授给文本条件的扩散模型。DreamBooth 在 DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation 中由 Ruiz 等人提出。
从某种意义上说,DreamBooth 类似于 传统的微调文本条件扩散模型的方法 ,除了几个小注意事项。这个示例假设你对扩散模型及其微调方式有基本的了解。以下是一些参考示例,可能会帮助你快速熟悉:
首先,让我们安装最新版本的 KerasCV 和 TensorFlow。
!pip install -q -U keras_cv==0.6.0
!pip install -q -U tensorflow
如果你正在运行代码,请确保你使用的是至少具有 24GB VRAM 的 GPU。
import math
import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from imutils import paths
from tensorflow import keras
... 是非常灵活的。通过教 Stable Diffusion 你最喜欢的视觉概念,你可以
以有趣的方式重新构建物体:
生成底层视觉概念的艺术呈现:
还有许多其他应用。我们欢迎你在这方面查看原始的 DreamBooth 论文。
DreamBooth 使用一种被称为“先验保护”的技术,有意义地指导训练过程,以便微调后的模型仍然能够保留你引入的视觉概念的一些先前语义。要了解更多关于“先验保护”的想法,请参考 这份文档。
在这里,我们需要介绍一些特定于 DreamBooth 的关键术语:
在代码中,这个生成过程看起来非常简单:
from tqdm import tqdm
import numpy as np
import hashlib
import keras_cv
import PIL
import os
class_images_dir = "class-images"
os.makedirs(class_images_dir, exist_ok=True)
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512, jit_compile=True)
class_prompt = "一张狗的照片"
num_imgs_to_generate = 200
for i in tqdm(range(num_imgs_to_generate)):
images = model.text_to_image(
class_prompt,
batch_size=3,
)
idx = np.random.choice(len(images))
selected_image = PIL.Image.fromarray(images[idx])
hash_image = hashlib.sha1(selected_image.tobytes()).hexdigest()
image_filename = os.path.join(class_images_dir, f"{hash_image}.jpg")
selected_image.save(image_filename)
为了保持此示例的运行时间短,示例的作者提前生成了一些类图像,使用了 这个笔记本。
注意,先验保留是DreamBooth中使用的可选技术,但几乎总是有助于提高生成图像的质量。
instance_images_root = tf.keras.utils.get_file(
origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
untar=True,
)
class_images_root = tf.keras.utils.get_file(
origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
untar=True,
)
首先,让我们加载图像路径。
instance_image_paths = list(paths.list_images(instance_images_root))
class_image_paths = list(paths.list_images(class_images_root))
然后,我们从路径中加载图像。
def load_images(image_paths):
images = [np.array(keras.utils.load_img(path)) for path in image_paths]
return images
接下来,我们使用一个工具函数绘制加载的图像。
def plot_images(images, title=None):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
if title is not None:
plt.title(title)
plt.imshow(images[i])
plt.axis("off")
实例图像:
plot_images(load_images(instance_image_paths[:5]))
类图像:
plot_images(load_images(class_image_paths[:5]))
数据集准备包括两个阶段:(1):准备标题,(2)处理图像。
# 由于我们正在使用先验保留,我们需要匹配使用的实例图像数量。
# 我们只需重复实例图像路径即可完成匹配。
new_instance_image_paths = []
for index in range(len(class_image_paths)):
instance_image = instance_image_paths[index % len(instance_image_paths)]
new_instance_image_paths.append(instance_image)
# 我们也只需为每张图像重复提示/标题。
unique_id = "sks"
class_label = "dog"
instance_prompt = f"a photo of {unique_id} {class_label}"
instance_prompts = [instance_prompt] * len(new_instance_image_paths)
class_prompt = f"a photo of {class_label}"
class_prompts = [class_prompt] * len(class_image_paths)
接下来,我们嵌入提示以节省计算资源。
import itertools
# 填充令牌和最大提示长度是特定于文本编码器的。
# 如果您使用不同的文本编码器,请确保相应更改。
padding_token = 49407
max_prompt_length = 77
# 加载分词器。
tokenizer = keras_cv.models.stable_diffusion.SimpleTokenizer()
# 分词和填充令牌的方法。
def process_text(caption):
tokens = tokenizer.encode(caption)
tokens = tokens + [padding_token] * (max_prompt_length - len(tokens))
return np.array(tokens)
# 将分词的标题合并到一个数组中。
tokenized_texts = np.empty(
(len(instance_prompts) + len(class_prompts), max_prompt_length)
)
for i, caption in enumerate(itertools.chain(instance_prompts, class_prompts)):
tokenized_texts[i] = process_text(caption)
# 我们还预先计算文本嵌入,以在训练期间节省一些内存。
POS_IDS = tf.convert_to_tensor([list(range(max_prompt_length))], dtype=tf.int32)
text_encoder = keras_cv.models.stable_diffusion.TextEncoder(max_prompt_length)
gpus = tf.config.list_logical_devices("GPU")
# 确保计算在GPU上进行。
# 注意,这在有GPU时会自动完成。
# 这个示例只是尝试展示您如何更明确地进行这项操作。
with tf.device(gpus[0].name):
embedded_text = text_encoder(
[tf.convert_to_tensor(tokenized_texts), POS_IDS], training=False
).numpy()
# 为确保text_encoder不会占用任何GPU空间。
del text_encoder
resolution = 512
auto = tf.data.AUTOTUNE
augmenter = keras.Sequential(
layers=[
keras_cv.layers.CenterCrop(resolution, resolution),
keras_cv.layers.RandomFlip(),
keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
]
)
def process_image(image_path, tokenized_text):
image = tf.io.read_file(image_path)
image = tf.io.decode_png(image, 3)
image = tf.image.resize(image, (resolution, resolution))
return image, tokenized_text
def apply_augmentation(image_batch, embedded_tokens):
return augmenter(image_batch), embedded_tokens
def prepare_dict(instance_only=True):
def fn(image_batch, embedded_tokens):
if instance_only:
batch_dict = {
"instance_images": image_batch,
"instance_embedded_texts": embedded_tokens,
}
return batch_dict
else:
batch_dict = {
"class_images": image_batch,
"class_embedded_texts": embedded_tokens,
}
return batch_dict
return fn
def assemble_dataset(image_paths, embedded_texts, instance_only=True, batch_size=1):
dataset = tf.data.Dataset.from_tensor_slices((image_paths, embedded_texts))
dataset = dataset.map(process_image, num_parallel_calls=auto)
dataset = dataset.shuffle(5, reshuffle_each_iteration=True)
dataset = dataset.batch(batch_size)
dataset = dataset.map(apply_augmentation, num_parallel_calls=auto)
prepare_dict_fn = prepare_dict(instance_only=instance_only)
dataset = dataset.map(prepare_dict_fn, num_parallel_calls=auto)
return dataset
instance_dataset = assemble_dataset(
new_instance_image_paths,
embedded_text[: len(new_instance_image_paths)],
)
class_dataset = assemble_dataset(
class_image_paths,
embedded_text[len(new_instance_image_paths) :],
instance_only=False,
)
train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset))
数据集准备好后,让我们快速检查其中的内容。
sample_batch = next(iter(train_dataset))
print(sample_batch[0].keys(), sample_batch[1].keys())
for k in sample_batch[0]:
print(k, sample_batch[0][k].shape)
for k in sample_batch[1]:
print(k, sample_batch[1][k].shape)
dict_keys(['instance_images', 'instance_embedded_texts']) dict_keys(['class_images', 'class_embedded_texts'])
instance_images (1, 512, 512, 3)
instance_embedded_texts (1, 77, 768)
class_images (1, 512, 512, 3)
class_embedded_texts (1, 77, 768)
在训练过程中,我们使用这些键来收集图像和文本嵌入并相应地进行连接。
我们的 DreamBooth 训练循环主要受到 这段脚本 的启发,该脚本由 Hugging Face 的 Diffusers 团队提供。然而,有一个重要的区别需要注意。我们在这个示例中只微调 UNet(负责预测噪声的模型),而不微调文本编码器。如果您正在寻找一个也执行文本编码器额外微调的实现,请参考 这个仓库。
import tensorflow.experimental.numpy as tnp
class DreamBoothTrainer(tf.keras.Model):
# 参考:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
def __init__(
self,
diffusion_model,
vae,
noise_scheduler,
use_mixed_precision=False,
prior_loss_weight=1.0,
max_grad_norm=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.diffusion_model = diffusion_model
self.vae = vae
self.noise_scheduler = noise_scheduler
self.prior_loss_weight = prior_loss_weight
self.max_grad_norm = max_grad_norm
self.use_mixed_precision = use_mixed_precision
self.vae.trainable = False
def train_step(self, inputs):
instance_batch = inputs[0]
class_batch = inputs[1]
instance_images = instance_batch["instance_images"]
instance_embedded_text = instance_batch["instance_embedded_texts"]
class_images = class_batch["class_images"]
class_embedded_text = class_batch["class_embedded_texts"]
images = tf.concat([instance_images, class_images], 0)
embedded_texts = tf.concat([instance_embedded_text, class_embedded_text], 0)
batch_size = tf.shape(images)[0]
with tf.GradientTape() as tape:
# 将图像投影到潜在空间并从中采样。
latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
# 了解更多关于魔法数字的信息:
# https://keras.io/examples/generative/fine_tune_via_textual_inversion/
latents = latents * 0.18215
# 采样我们将添加到潜在变量的噪声。
noise = tf.random.normal(tf.shape(latents))
# 为每个图像采样一个随机时间步长。
timesteps = tnp.random.randint(
0, self.noise_scheduler.train_timesteps, (batch_size,)
)
# 根据每个时间步的噪声幅度向潜在变量添加噪声
# (这是正向扩散过程)。
noisy_latents = self.noise_scheduler.add_noise(
tf.cast(latents, noise.dtype), noise, timesteps
)
# 根据预测类型获取损失的目标
# 目前只取采样噪声。
target = noise # noise_schedule.predict_epsilon == True
# 预测噪声残差并计算损失。
timestep_embedding = tf.map_fn(
lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
)
model_pred = self.diffusion_model(
[noisy_latents, timestep_embedding, embedded_texts], training=True
)
loss = self.compute_loss(target, model_pred)
if self.use_mixed_precision:
loss = self.optimizer.get_scaled_loss(loss)
# 更新扩散模型的参数。
trainable_vars = self.diffusion_model.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
if self.use_mixed_precision:
gradients = self.optimizer.get_unscaled_gradients(gradients)
gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
return {m.name: m.result() for m in self.metrics}
def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
half = dim // 2
log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
freqs = tf.math.exp(
-log_max_period * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
return embedding
def sample_from_encoder_outputs(self, outputs):
mean, logvar = tf.split(outputs, 2, axis=-1)
logvar = tf.clip_by_value(logvar, -30.0, 20.0)
std = tf.exp(0.5 * logvar)
sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
return mean + std * sample
def compute_loss(self, target, model_pred):
# 将噪声和模型预测分成两部分并分别计算损失
# 由于第一部分的输入是实例样本,第二部分是类别样本,因此我们按此进行分块。
model_pred, model_pred_prior = tf.split(
model_pred, num_or_size_splits=2, axis=0
)
target, target_prior = tf.split(target, num_or_size_splits=2, axis=0)
# 计算实例损失。
loss = self.compiled_loss(target, model_pred)
# 计算先验损失。
prior_loss = self.compiled_loss(target_prior, model_pred_prior)
# 将先验损失添加到实例损失中。
loss = loss + self.prior_loss_weight * prior_loss
return loss
def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
# 重写此方法将允许我们直接使用 `ModelCheckpoint`
# 回调与此训练器类。在这种情况下,它将
# 仅检查 `diffusion_model`,因为这是我们在微调期间训练的内容。
self.diffusion_model.save_weights(
filepath=filepath,
overwrite=overwrite,
save_format=save_format,
options=options,
)
def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
# 同样重写 `load_weights()`,以便我们可以直接在
# 训练器类对象上调用它。
self.diffusion_model.load_weights(
filepath=filepath,
by_name=by_name,
skip_mismatch=skip_mismatch,
options=options,
)
# 如果没有使用具有张量核心的 GPU,请对此进行注释。
tf.keras.mixed_precision.set_global_policy("mixed_float16")
use_mp = True # 如果没有使用具有张量核心的 GPU,请将其设置为 False。
image_encoder = keras_cv.models.stable_diffusion.ImageEncoder()
dreambooth_trainer = DreamBoothTrainer(
diffusion_model=keras_cv.models.stable_diffusion.DiffusionModel(
resolution, resolution, max_prompt_length
),
# 从编码器中移除顶层,这样可以切断方差并仅返回均值。
vae=tf.keras.Model(
image_encoder.input,
image_encoder.layers[-2].output,
),
noise_scheduler=keras_cv.models.stable_diffusion.NoiseScheduler(),
use_mixed_precision=use_mp,
)
# 这些超参数来自 Hugging Face 的这个教程:
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
learning_rate = 5e-6
beta_1, beta_2 = 0.9, 0.999
weight_decay = (1e-2,)
epsilon = 1e-08
optimizer = tf.keras.optimizers.experimental.AdamW(
learning_rate=learning_rate,
weight_decay=weight_decay,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon,
)
dreambooth_trainer.compile(optimizer=optimizer, loss="mse")
我们首先计算需要训练的轮次。
num_update_steps_per_epoch = train_dataset.cardinality()
max_train_steps = 800
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
print(f"训练轮次为 {epochs}。")
训练轮次为 4。
然后我们开始训练!
ckpt_path = "dreambooth-unet.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
ckpt_path,
save_weights_only=True,
monitor="loss",
mode="min",
)
dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=[ckpt_callback])
第 1 轮/共 4 轮
200/200 [==============================] - 301s 462ms/step - loss: 0.1203
第 2 轮/共 4 轮
200/200 [==============================] - 94s 469ms/step - loss: 0.1139
第 3 轮/共 4 轮
200/200 [==============================] - 94s 469ms/step - loss: 0.1016
第 4 轮/共 4 轮
200/200 [==============================] - 94s 469ms/step - loss: 0.1231
<keras.callbacks.History at 0x7f19726600a0>
我们进行了各种实验,使用了稍微修改过的示例版本。我们的实验基于 这个仓库,并受到 这篇博客文章的启发。
首先,让我们看看如何使用微调后的检查点进行推断。
# 初始化一个新的稳定扩散模型。
dreambooth_model = keras_cv.models.StableDiffusion(
img_width=resolution, img_height=resolution, jit_compile=True
)
dreambooth_model.diffusion_model.load_weights(ckpt_path)
# 请注意,唯一标识符和类别已在提示中使用。
prompt = f"A photo of {unique_id} {class_label} in a bucket"
num_imgs_to_gen = 3
images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)
使用此模型检查点,您承认其使用受 CreativeML Open RAIL-M 许可证的条款约束,网址为 https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
50/50 [==============================] - 42s 160ms/step
现在,让我们加载另一个实验的检查点,在该实验中我们还微调了文本编码器以及 UNet:
unet_weights = tf.keras.utils.get_file(
origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-unet.h5"
)
text_encoder_weights = tf.keras.utils.get_file(
origin="https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-text_encoder.h5"
)
dreambooth_model.diffusion_model.load_weights(unet_weights)
dreambooth_model.text_encoder.load_weights(text_encoder_weights)
images_dreamboothed = dreambooth_model.text_to_image(prompt, batch_size=num_imgs_to_gen)
plot_images(images_dreamboothed, prompt)
从 https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-unet.h5 下载数据
3439088208/3439088208 [==============================] - 67s 0us/step
从 https://huggingface.co/chansung/dreambooth-dog/resolve/main/lr%409e-06-max_train_steps%40200-train_text_encoder%40True-text_encoder.h5 下载数据
492466760/492466760 [==============================] - 9s 0us/step
50/50 [==============================] - 8s 159ms/step
text_to_image()
生成图像的默认步骤数是50。
让我们将其增加到100。
images_dreamboothed = dreambooth_model.text_to_image(
prompt, batch_size=num_imgs_to_gen, num_steps=100
)
plot_images(images_dreamboothed, prompt)
100/100 [==============================] - 16秒 159毫秒/步
随意尝试不同的提示(别忘了添加唯一标识符和类标签!)以查看结果如何变化。我们欢迎您查看我们的代码库和更多实验结果这里。您还可以阅读这篇博文以获取更多想法。