代码示例 / 生成式深度学习 / 微调 Stable Diffusion

微调 Stable Diffusion

作者: Sayak PaulChansung Park
创建日期: 2022/12/28
最后修改日期: 2023/01/13
描述: 使用自定义图像-文本数据集微调 Stable Diffusion。

在 Colab 中查看 GitHub 源码


介绍

本教程演示如何在自定义的 {图像, 文本} 对数据集上微调 Stable Diffusion 模型。 我们在 Hugging Face 提供的微调脚本的基础上进行构建 在这里

我们假设您对 Stable Diffusion 模型有一定的了解。 如果您想获得更多信息,以下资源可能会对您有所帮助:

强烈建议您使用至少 30GB 内存的 GPU 来执行代码。

通过本指南的学习,您将能够生成有趣的宝可梦图像:

custom-pokemons

本教程依赖于 KerasCV 0.4.0。此外,使用 AdamW 进行混合精度计算时,我们至少需要 TensorFlow 2.11。

!pip install keras-cv==0.6.0 -q
!pip install -U tensorflow -q
!pip install keras-core -q

我们正在微调什么?

Stable Diffusion 模型可以分解为几个关键模型:

  • 一个文本编码器,用于将输入提示投影到潜在空间。(与图像相关的标题称为“提示”。)
  • 一个变分自编码器 (VAE),用于将输入图像投影到潜在空间,充当图像向量空间。
  • 一个扩散模型,细化潜在向量,并生成另一个潜在向量,依赖于编码的文本提示。
  • 一个解码器,根据扩散模型提供的潜在向量生成图像。

值得注意的是,在从文本提示生成图像的过程中,通常不使用图像编码器。

然而,在微调过程中,工作流程如下:

  1. 输入文本提示通过文本编码器投影到潜在空间。
  2. 输入图像通过 VAE 的图像编码器部分投影到潜在空间。
  3. 在给定的时间步长上,为图像潜在向量添加少量噪声。
  4. 扩散模型使用来自这两个空间的潜在向量以及时间步长嵌入来预测添加到图像潜在的噪声。
  5. 在步骤 3 中预测的噪声与原始噪声之间计算重构损失。
  6. 最后,使用梯度下降优化扩散模型参数,以减少该损失。

请注意,在微调过程中,仅更新扩散模型参数,而保持(预训练的)文本和图像编码器不变。

如果这听起来很复杂,不用担心。代码要简单得多!


导入

from textwrap import wrap
import os

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from tensorflow import keras

数据加载

我们使用数据集 宝可梦 BLIP 标题。 但是,我们将使用一个略有不同的版本,该版本源自原始数据集,以更好地与 tf.data 适配。有关更多详细信息,请参见 文档

data_path = tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    untar=True,
)

data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))

data_frame["image_path"] = data_frame["image_path"].apply(
    lambda x: os.path.join(data_path, x)
)
data_frame.head()
image_path caption
0 /home/jupyter/.keras/datasets/pokemon_dataset/... 一只绿色的宝可梦,红色的眼睛的画
1 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个绿色和黄色的玩具,红色的鼻子
2 /home/jupyter/.keras/datasets/pokemon_dataset/... 一颗红白相间的球,愤怒的表情
3 /home/jupyter/.keras/datasets/pokemon_dataset/... 一颗卡通球,上面有微笑的表情
4 /home/jupyter/.keras/datasets/pokemon_dataset/... 一堆球,上面画着脸

由于我们只有833个{image, caption}对,我们可以预计算文本嵌入。 此外,在微调过程中,文本编码器将保持冻结,因此我们可以通过这样做节省一些计算。

在使用文本编码器之前,我们需要对标题进行标记化。

# 填充标记和最大提示长度是特定于文本编码器的。
# 如果您使用不同的文本编码器,请确保相应地更改它们。
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# 加载分词器。
tokenizer = SimpleTokenizer()

#  标记化和填充标记的方法。
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)


# 将标记化的标题收集到一个数组中。
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

all_captions = list(data_frame["caption"].values)
for i, caption in enumerate(all_captions):
    tokenized_texts[i] = process_text(caption)

准备一个 tf.data.Dataset

在这一部分,我们将从输入图像文件路径及其对应的标题标记准备一个 tf.data.Dataset 对象。 本节将包括以下内容:

  • 从标记化的标题预计算文本嵌入。
  • 加载和增强输入图像。
  • 数据集的打乱和批处理。
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
        keras_cv.layers.RandomFlip(),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)
text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
    return (
        image_batch,
        token_batch,
        text_encoder([token_batch, POS_IDS], training=False),
    )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    return {
        "images": image_batch,
        "tokens": token_batch,
        "encoded_text": encoded_text_batch,
    }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
    dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
    return dataset.prefetch(AUTO)

基线稳定扩散模型是使用512x512分辨率的图像训练的。 使用较高分辨率图像训练的模型不太可能良好地迁移到较低分辨率图像上。 然而,当前模型如果保持分辨率为512x512(而不启用混合精度),将导致OOM。 因此,为了进行交互式演示,我们将输入分辨率保持为256x256。

# 准备数据集。
training_dataset = prepare_dataset(
    np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
)

# 获取一个样本批次并进行检查。
sample_batch = next(iter(training_dataset))

for k in sample_batch:
    print(k, sample_batch[k].shape)
图像 (4, 256, 256, 3)
标记 (4, 77)
编码文本 (4, 77, 768)

我们还可以查看训练图像及其对应的标题。

plt.figure(figsize=(20, 10))

for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
    text = text.replace("<|startoftext|>", "")
    text = text.replace("<|endoftext|>", "")
    text = "\n".join(wrap(text, 12))
    plt.title(text, fontsize=15)

    plt.axis("off")

png


A trainer class for the fine-tuning loop

class Trainer(tf.keras.Model):
    # 参考:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # 将图像投影到潜在空间并从中采样。
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # 了解有关此魔法数字的更多信息:
            # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # 采样我们将添加到潜在变量中的噪声。
            noise = tf.random.normal(tf.shape(latents))

            # 为每幅图像采样一个随机时间步。
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # 根据每个时间步的噪声幅度将噪声添加到潜在变量中
            # (这是前向扩散过程)。
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # 根据预测类型获取损失的目标
            # 目前仅为采样噪声。
            target = noise  # noise_schedule.predict_epsilon == True

            # 预测噪声残差并计算损失。
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # 更新扩散模型的参数。
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # 重写此方法将允许我们直接使用 `ModelCheckpoint`
        # 回调与此培训类。在这种情况下,它将
        # 仅检查点 `diffusion_model`,因为这就是我们在微调过程中要训练的内容。
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

一个重要的实现细节需要注意:我们不是直接使用图像编码器(这是一个VAE)生成的潜在向量,而是从其预测的均值和对数方差中进行采样。通过这种方式,我们可以获得更好的样本质量和多样性。

通常情况下,在微调这些模型时,会添加对混合精度训练的支持,以及模型权重的指数移动平均。然而,出于简洁性考虑,我们省略了这些元素。关于这一点,后面在教程中会有更多内容。


初始化训练器并编译它

# 如果底层GPU具有张量核心,则启用混合精度训练。
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder()
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # 从编码器中移除最上层,不再返回方差,仅返回均值。
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# 这些超参数来自Hugging Face的本教程:
# https://huggingface.co/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
weight_decay = (1e-2,)
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")

微调

为了缩短本教程的运行时间,我们仅微调一个时期。

epochs = 1
ckpt_path = "finetuned_stable_diffusion.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])

推理

我们在512x512的图像分辨率上将模型微调了60个时期。为了使这种分辨率下的训练成为可能,我们加入了混合精度支持。您可以查看 这个仓库 以获取更多详细信息。它还提供了对微调模型参数的指数移动平均和模型检查点的支持。

在本节中,我们将使用在60个时期微调后得到的检查点。

weights_path = tf.keras.utils.get_file(
    origin="https://huggingface.co/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
# 我们只需重新加载微调的扩散模型的权重。
pokemon_model.diffusion_model.load_weights(weights_path)
通过使用此模型检查点,您承认其使用受CreativeML Open RAIL-M许可证的条款约束,网址为 https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

现在,我们可以对该模型进行测试。

prompts = ["尤达", "凯蒂猫", "一只红眼的精灵"]
images_to_generate = 3
outputs = {}

for prompt in prompts:
    generated_images = pokemon_model.text_to_image(
        prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
    )
    outputs.update({prompt: generated_images})
25/25 [==============================] - 17s 231ms/step
25/25 [==============================] - 6s 229ms/step
25/25 [==============================] - 6s 229ms/step

经过60个时期的微调(一个好的数量大约为70),生成的图像并未达到预期效果。因此,我们尝试了Stable Diffusion在推理时所需的步数和unconditional_guidance_scale参数。

我们发现,在这个检查点下,将unconditional_guidance_scale设置为40可以得到最佳结果。

def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.title(title, fontsize=12)
        plt.axis("off")


for prompt in outputs:
    plot_images(outputs[prompt], prompt)

png

png

png

我们可以注意到,模型已开始适应我们数据集的风格。您可以查看 附属仓库 以获取更多比较和评论。如果您想尝试演示,您可以查看 this resource.


结论与致谢

我们展示了如何在自定义数据集上微调稳定扩散模型。虽然结果远未令人满意,但我们相信随着微调轮次的增加,结果有望改善。为了实现这一点,支持梯度累积和分布式训练至关重要。这可以被视为本教程的下一步。

稳定扩散模型还可以通过另一种有趣的方式进行微调,称为文本反转。您可以参考 this tutorial 以了解更多信息。

我们要感谢Google的ML开发者项目团队提供的GCP信用支持。我们还要感谢Hugging Face团队提供的 fine-tuning script 。它的可读性和易懂性都很好。