代码示例 / 生成式深度学习 / 通过文本反转教StableDiffusion新的概念

通过文本反转教StableDiffusion新的概念

作者: Ian Stenbit, lukewood
创建日期: 2022/12/09
最后修改: 2022/12/09
描述: 使用KerasCV的StableDiffusion实现学习新的视觉概念。

在Colab中查看 GitHub源代码


文本反转

自发布以来,StableDiffusion迅速成为生成机器学习社区的宠儿。 大量的流量导致了开源贡献的改进、繁重的提示工程,甚至新算法的发明。

也许最令人印象深刻的新算法是 文本反转,该算法在 一张图胜过千言万语:使用文本反转个性化文本到图像生成中提出。

文本反转是通过微调教导图像生成器特定视觉概念的过程。在下图中,您可以看到这个过程的示例,作者教模型新的概念,将其称为 "S_*"。

https://i.imgur.com/KqEeBsM.jpg

从概念上讲,文本反转通过学习一个新文本令牌的令牌嵌入来工作,同时保持StableDiffusion的其余组件不变。

本指南将向您展示如何使用文本反转算法微调KerasCV中提供的StableDiffusion模型。到本指南结束时,您将能够写出“作为<my-funny-cat-token>的灰袍甘道夫”。

https://i.imgur.com/rcb1Yfx.png

首先,让我们导入所需的包,并创建一个StableDiffusion实例,以便我们可以使用其一些子组件进行微调。

!pip install -q git+https://github.com/keras-team/keras-cv.git
!pip install -q tensorflow==2.11.0
import math

import keras_cv
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.stable_diffusion import NoiseScheduler
from tensorflow import keras
import matplotlib.pyplot as plt

stable_diffusion = keras_cv.models.StableDiffusion()
使用此模型检查点即表示您承认其使用受CreativeML Open RAIL-M许可证的条款限制,链接:https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

接下来,让我们定义一个可视化工具,以展示生成的图像:

def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")

组装文本-图像对数据集

为了训练新标记的嵌入,我们首先必须组装一个包含文本-图像对的数据集。数据集的每个样本必须包含一个我们正在教给StableDiffusion的概念图像,以及一个准确表示图像内容的标题。在本教程中,我们将教给StableDiffusion Luke和Ian的GitHub头像的概念:

首先,让我们构建一个猫玩偶的图像数据集:

def assemble_image_dataset(urls):
    # 获取所有远程文件
    files = [tf.keras.utils.get_file(origin=url) for url in urls]

    # 缩放图像
    resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
    images = [keras.utils.load_img(img) for img in files]
    images = [keras.utils.img_to_array(img) for img in images]
    images = np.array([resize(img) for img in images])

    # StableDiffusion图像编码器要求图像归一化到[-1, 1]的像素值范围
    images = images / 127.5 - 1

    # 创建tf.data.Dataset
    image_dataset = tf.data.Dataset.from_tensor_slices(images)

    # 随机打乱并引入随机噪声
    image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
    image_dataset = image_dataset.map(
        cv_layers.RandomCropAndResize(
            target_size=(512, 512),
            crop_area_factor=(0.8, 1.0),
            aspect_ratio_factor=(1.0, 1.0),
        ),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    image_dataset = image_dataset.map(
        cv_layers.RandomFlip(mode="horizontal"),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return image_dataset

接下来,我们组装一个文本数据集:

MAX_PROMPT_LENGTH = 77
placeholder_token = "<my-funny-cat-token>"


def pad_embedding(embedding):
    return embedding + (
        [stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
    )


stable_diffusion.tokenizer.add_tokens(placeholder_token)


def assemble_text_dataset(prompts):
    prompts = [prompt.format(placeholder_token) for prompt in prompts]
    embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
    embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
    text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
    text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
    return text_dataset

最后,我们将数据集组合在一起以生成一个文本-图像对数据集。

def assemble_dataset(urls, prompts):
    image_dataset = assemble_image_dataset(urls)
    text_dataset = assemble_text_dataset(prompts)
    # 图像数据集相对较短,因此我们重复它以匹配文本提示数据集的长度
    image_dataset = image_dataset.repeat()
    # 我们使用文本提示数据集来确定数据集的长度。由于提示相对较少,我们将数据集重复5次。
    # 我们发现这在经验上提高了结果。
    text_dataset = text_dataset.repeat(5)
    return tf.data.Dataset.zip((image_dataset, text_dataset))

为了确保我们的提示是描述性的,我们使用非常通用的提示。

让我们试试这个样本图像和提示。

train_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "一张 {} 的照片",
        "一幅 {} 的渲染画",
        "一张 {} 的裁剪照片",
        "一张 {} 的照片",
        "一张干净的 {} 的照片",
        "一张黑暗的 {} 的照片",
        "一张我 {} 的照片",
        "一张酷炫的 {} 的照片",
        "一张 {} 的特写照片",
        "一张明亮的 {} 的照片",
        "一张 {} 的裁剪照片",
        "一张 {} 的照片",
        "一张好的 {} 的照片",
        "一张一个 {} 的照片",
        "一张 {} 的特写照片",
        "一幅 {} 的再现画",
        "一张干净的 {} 的照片",
        "一幅 {} 的再现画",
        "一张好看的 {} 的照片",
        "一张好的 {} 的照片",
        "一张好的 {} 的照片",
        "一张小的 {} 的照片",
        "一张奇怪的 {} 的照片",
        "一张大的 {} 的照片",
        "一张酷炫的 {} 的照片",
        "一张小的 {} 的照片",
    ],
)

关于提示准确性的重要性

在我们第一次尝试撰写本指南时,我们在数据集中包含了这些猫娃娃的群体图像,但继续使用上面列出的通用提示。 我们的结果在经验上很差。例如,这里是使用这种方法的猫娃娃甘道夫:

mediocre-wizard

它在概念上接近,但没有达到最佳效果。

为了弥补这一点,我们开始尝试将图像分割为单个猫娃娃的图像和猫娃娃的群体图像。 在这次分割后,我们为群体照想出了新的提示。

在准确表示内容的文本-图像对上进行训练显著提高了我们结果的质量。这说明了提示准确性的重要性。

除了将图像分为单个图像和群体图像外,我们还删除了一些不准确的提示,例如“黑暗的 {} 的照片”。

考虑到这一点,我们在下面组装了最终的训练数据集:

single_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/VIedH1X.jpg",
        "https://i.imgur.com/eBw13hE.png",
        "https://i.imgur.com/oJ3rSg7.png",
        "https://i.imgur.com/5mCL6Df.jpg",
        "https://i.imgur.com/4Q6WWyI.jpg",
    ],
    prompts=[
        "一张 {} 的照片",
        "一幅 {} 的渲染画",
        "一张 {} 的裁剪照片",
        "一张 {} 的照片",
        "一张干净的 {} 的照片",
        "一张我 {} 的照片",
        "一张酷炫的 {} 的照片",
        "一张 {} 的特写照片",
        "一张明亮的 {} 的照片",
        "一张 {} 的裁剪照片",
        "一张 {} 的照片",
        "一张好的 {} 的照片",
        "一张一个 {} 的照片",
        "一张 {} 的特写照片",
        "一幅 {} 的再现画",
        "一张干净的 {} 的照片",
        "一幅 {} 的再现画",
        "一张好看的 {} 的照片",
        "一张好的 {} 的照片",
        "一张好的 {} 的照片",
        "一张小的 {} 的照片",
        "一张奇怪的 {} 的照片",
        "一张大的 {} 的照片",
        "一张酷炫的 {} 的照片",
        "一张小的 {} 的照片",
    ],
)

https://i.imgur.com/gQCRjK6.png

看起来不错!

接下来,我们组装一个我们 GitHub 头像的群体数据集:

group_ds = assemble_dataset(
    urls=[
        "https://i.imgur.com/yVmZ2Qa.jpg",
        "https://i.imgur.com/JbyFbZJ.jpg",
        "https://i.imgur.com/CCubd3q.jpg",
    ],
    prompts=[
        "一张 {} 的团队照片",
        "一张 {} 的团队渲染图",
        "一张裁剪过的 {} 团队照片",
        "一张 {} 的团队照片",
        "一张干净的 {} 团队照片",
        "我的 {} 团队照片",
        "一张很酷的 {} 团队照片",
        "一张 {} 团队的特写照片",
        "一张明亮的 {} 团队照片",
        "一张裁剪过的 {} 团队照片",
        "一张 {} 的团队照片",
        "一张好的 {} 团队照片",
        "一张一组 {} 的照片",
        "一张 {} 团队的特写照片",
        "一张 {} 团队的渲染",
        "一张干净的 {} 团队照片",
        "一张 {} 团队的渲染",
        "一张漂亮的 {} 团队照片",
        "一张好的 {} 团队照片",
        "一张好的 {} 团队照片",
        "一张小型 {} 团队照片",
        "一张奇怪的 {} 团队照片",
        "一张大型 {} 团队照片",
        "一张很酷的 {} 团队照片",
        "一张小型 {} 团队照片",
    ],
)

https://i.imgur.com/GY9Pf3D.png

最后,我们将两个数据集连接在一起:

train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
    train_ds.cardinality(), reshuffle_each_iteration=True
)

向文本编码器添加新令牌

接下来,我们为StableDiffusion模型创建一个新的文本编码器,并将我们的新嵌入添加到模型中,为''。

tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
    tf.constant(tokenized_initializer)
)

# 获取.vocab的长度,而不是tokenizer
new_vocab_size = len(stable_diffusion.tokenizer.vocab)

# 嵌入层是文本编码器中的第二层
old_token_weights = stable_diffusion.text_encoder.layers[
    2
].token_embedding.get_weights()
old_position_weights = stable_diffusion.text_encoder.layers[
    2
].position_embedding.get_weights()

old_token_weights = old_token_weights[0]
new_weights = np.expand_dims(new_weights, axis=0)
new_weights = np.concatenate([old_token_weights, new_weights], axis=0)

让我们构建一个新的TextEncoder并准备它。

# 必须将download_weights设置为False,以便我们可以初始化(否则会尝试加载权重)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
    keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
    vocab_size=new_vocab_size,
    download_weights=False,
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
    # 第二层是嵌入层,因此我们在复制权重时省略它
    if index == 2:
        continue
    new_encoder.layers[index].set_weights(layer.get_weights())


new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)

stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)

训练

现在我们可以进入激动人心的部分:训练!

在TextualInversion中,唯一要训练的模型部分是嵌入向量。 让我们冻结模型的其余部分。

stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True

stable_diffusion.text_encoder.layers[2].trainable = True


def traverse_layers(layer):
    if hasattr(layer, "layers"):
        for layer in layer.layers:
            yield layer
    if hasattr(layer, "token_embedding"):
        yield layer.token_embedding
    if hasattr(layer, "position_embedding"):
        yield layer.position_embedding


for layer in traverse_layers(stable_diffusion.text_encoder):
    if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
        layer.trainable = True
    else:
        layer.trainable = False

new_encoder.layers[2].position_embedding.trainable = False

让我们确认设置为可训练的权重是正确的。

all_models = [
    stable_diffusion.text_encoder,
    stable_diffusion.diffusion_model,
    stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])
[[TensorShape([49409, 768])], [], []]

训练新的嵌入

为了训练嵌入,我们需要几个工具。 我们从KerasCV导入一个NoiseScheduler,并在下面定义以下工具:

  • sample_from_encoder_outputs 是一个包装器,用于基础StableDiffusion图像编码器,它从图像编码器产生的统计分布中采样,而不是像许多其他SD应用一样仅取平均值。
  • get_timestep_embedding 为扩散模型产生指定时间步的嵌入。
  • get_position_ids 为文本编码器生成一个位置ID的张量(它只是从[1, MAX_PROMPT_LENGTH] 的一系列数字)。
# 从编码器中移除顶部层,这将切断方差并仅返回
# 均值
training_image_encoder = keras.Model(
    stable_diffusion.image_encoder.input,
    stable_diffusion.image_encoder.layers[-2].output,
)


def sample_from_encoder_outputs(outputs):
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)
    std = tf.exp(0.5 * logvar)
    sample = tf.random.normal(tf.shape(mean))
    return mean + std * sample


def get_timestep_embedding(timestep, dim=320, max_period=10000):
    half = dim // 2
    freqs = tf.math.exp(
        -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
    )
    args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
    embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
    return embedding


def get_position_ids():
    return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

接下来,我们实现一个 StableDiffusionFineTuner,它是 keras.Model 的子类,重载了 train_step 以训练我们文本编码器的令牌嵌入。这是文本反转算法的核心。

抽象地说,训练步骤从冻结的 SD 图像编码器的潜在分布中取样训练图像,给该样本添加噪声,然后将这个带噪声的样本传递给冻结的扩散模型。扩散模型的隐藏状态是对应于图像提示的文本编码器的输出。

我们最终的目标状态是,扩散模型能够使用文本编码作为隐藏状态从样本中分离噪声,因此我们的损失是噪声与扩散模型输出的均方误差(理想情况下,扩散模型已从噪声中去除了图像潜在值)。

我们仅计算文本编码器的令牌嵌入的梯度,在训练步骤中,我们将除了我们正在学习的令牌之外的所有令牌的梯度清零。

有关训练步骤的更多详细信息,请参见内联代码注释。

class StableDiffusionFineTuner(keras.Model):
    def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.noise_scheduler = noise_scheduler

    def train_step(self, data):
        images, embeddings = data

        with tf.GradientTape() as tape:
            # 从训练图像的预测分布中取样
            latents = sample_from_encoder_outputs(training_image_encoder(images))
            # 潜在值必须下采样以匹配训练 StableDiffusion 时使用的潜在值的尺度。
            # 这个数字确实只是他们在训练模型时选择的一个“魔术”常数。
            latents = latents * 0.18215

            # 产生与潜在样本形状相同的随机噪声
            noise = tf.random.normal(tf.shape(latents))
            batch_dim = tf.shape(latents)[0]

            # 为批中的每个样本随机选择一个时间步
            timesteps = tf.random.uniform(
                (batch_dim,),
                minval=0,
                maxval=noise_scheduler.train_timesteps,
                dtype=tf.int64,
            )

            # 根据每个样本的时间步将噪声添加到潜在值中
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

            # 编码训练样本中的文本以用作扩散模型中的隐藏状态
            encoder_hidden_state = self.stable_diffusion.text_encoder(
                [embeddings, get_position_ids()]
            )

            # 为批中每个样本随机选择的时间步计算时间步嵌入
            timestep_embeddings = tf.map_fn(
                fn=get_timestep_embedding,
                elems=timesteps,
                fn_output_signature=tf.float32,
            )

            # 调用扩散模型
            noise_pred = self.stable_diffusion.diffusion_model(
                [noisy_latents, timestep_embeddings, encoder_hidden_state]
            )

            # 计算均方误差损失并进行减少。
            loss = self.compiled_loss(noise_pred, noise)
            loss = tf.reduce_mean(loss, axis=2)
            loss = tf.reduce_mean(loss, axis=1)
            loss = tf.reduce_mean(loss)

        # 加载可训练权重并计算它们的梯度
        trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
        grads = tape.gradient(loss, trainable_weights)

        # 梯度存储在索引切片中,因此我们必须找到包含占位符令牌的切片的索引。
        index_of_placeholder_token = tf.reshape(tf.where(grads[0].indices == 49408), ())
        condition = grads[0].indices == 49408
        condition = tf.expand_dims(condition, axis=-1)

        # 重写梯度,将所有不是占位符令牌的切片的梯度清零,有效冻结其他令牌的权重。
        grads[0] = tf.IndexedSlices(
            values=tf.where(condition, grads[0].values, 0),
            indices=grads[0].indices,
            dense_shape=grads[0].dense_shape,
        )

        self.optimizer.apply_gradients(zip(grads, trainable_weights))
        return {"loss": loss}

在我们开始训练之前,让我们看看 StableDiffusion 为我们的令牌生成了什么。

generated = stable_diffusion.text_to_image(
    f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)
25/25 [==============================] - 19s 314ms/step

png

如您所见,模型仍然将我们的令牌视为一只猫,因为这是我们用来初始化自定义令牌的种子令牌。

现在,为了开始训练,我们可以像其他 Keras 模型一样简单地 compile() 我们的模型。在此之前,我们还为训练实例化了一个噪声调度器,并配置了诸如学习率和优化器等训练参数。

noise_scheduler = NoiseScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    train_timesteps=1000,
)
trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
    weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)

trainer.compile(
    optimizer=optimizer,
    # 我们在训练步骤中手动执行了缩减,因此这里不需要。
    loss=keras.losses.MeanSquaredError(reduction="none"),
)

为了监控训练,我们可以生成一个 keras.callbacks.Callback 以使用我们的自定义令牌每个周期生成一些图像。

我们创建了三个具有不同提示的回调,以便我们可以看到它们在训练过程中的进展。我们使用固定种子,以便容易查看学习到的令牌的进展。

class GenerateImages(keras.callbacks.Callback):
    def __init__(
        self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
    ):
        super().__init__(**kwargs)
        self.stable_diffusion = stable_diffusion
        self.prompt = prompt
        self.seed = seed
        self.frequency = frequency
        self.steps = steps

    def on_epoch_end(self, epoch, logs):
        if epoch % self.frequency == 0:
            images = self.stable_diffusion.text_to_image(
                self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
            )
            plot_images(
                images,
            )


cbs = [
    GenerateImages(
        stable_diffusion, prompt=f"一幅 {placeholder_token} 的油画", seed=1337
    ),
    GenerateImages(
        stable_diffusion, prompt=f"甘道夫灰衣作为一个 {placeholder_token}", seed=1337
    ),
    GenerateImages(
        stable_diffusion,
        prompt=f"两个 {placeholder_token} 结婚,照片真实感,高质量",
        seed=1337,
    ),
]

现在,剩下的就是调用 model.fit()

trainer.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=cbs,
)
Epoch 1/50 50/50 [==============================] - 16s 318ms/step 50/50 [==============================] - 16s 318ms/step 50/50 [==============================] - 16s 318ms/step 250/250 [==============================] - 194s 469ms/step - loss: 0.1533 Epoch 2/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1557 Epoch 3/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1359 Epoch 4/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1693 Epoch 5/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1475 Epoch 6/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1472 Epoch 7/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1533 Epoch 8/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1450 Epoch 9/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1639 Epoch 10/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1351 Epoch 11/50 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 317ms/step 250/250 [==============================] - 116s 464ms/step - loss: 0.1474 Epoch 12/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1737 Epoch 13/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1427 Epoch 14/50 250/250 [==============================] - 68s 269ms/step - loss: 0.1698 Epoch 15/50 250/250 [==============================] - 68s 270ms/step - loss: 0.1424 Epoch 16/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1339 Epoch 17/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1397 Epoch 18/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1469 Epoch 19/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1649 Epoch 20/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1582 Epoch 21/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 316ms/step 50/50 [==============================] - 16s 316ms/step 250/250 [==============================] - 116s 462ms/step - loss: 0.1331 Epoch 22/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1319 Epoch 23/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1521 Epoch 24/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1486 Epoch 25/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1449 Epoch 26/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1349 Epoch 27/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1454 Epoch 28/50 250/250 [==============================] - 68s 268ms/step - loss: 0.1394 Epoch 29/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1489 Epoch 30/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1338 Epoch 31/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 320ms/step 50/50 [==============================] - 16s 315ms/step 250/250 [==============================] - 116s 462ms/step - loss: 0.1328 Epoch 32/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1693 Epoch 33/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1420 Epoch 34/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1255 Epoch 35/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1239 Epoch 36/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1558 Epoch 37/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1527 Epoch 38/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1461 Epoch 39/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1555 Epoch 40/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1515 Epoch 41/50 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 315ms/step 50/50 [==============================] - 16s 315ms/step 250/250 [==============================] - 116s 461ms/step - loss: 0.1291 Epoch 42/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1474 Epoch 43/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1908 Epoch 44/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1506 Epoch 45/50 250/250 [==============================] - 68s 267ms/step - loss: 0.1424 Epoch 46/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1601 Epoch 47/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1312 Epoch 48/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1524 Epoch 49/50 250/250 [==============================] - 67s 266ms/step - loss: 0.1477 Epoch 50/50 250/250 [==============================] - 67s 267ms/step - loss: 0.1397
</div>

![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_2.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_3.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_4.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_5.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_6.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_7.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_8.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_9.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_10.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_11.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_12.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_13.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_14.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_15.png)





![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_16.png)



看到模型如何随着时间学习我们的新标记真的很有趣。尽情尝试,看看如何调整训练参数和您的训练数据集,以生成最佳图像!

---
## 测试微调模型

现在是非常有趣的部分。我们为我们的自定义标记学习了一个标记嵌入,因此现在我们可以像对待任何其他标记一样使用稳定扩散生成图像!

这里有一些有趣的示例提示来帮助您入门,以及我们猫玩具标记的示例输出!


```python
generated = stable_diffusion.text_to_image(
    f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
    "golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
    "character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 316ms/step
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_42_1.png)
generated = stable_diffusion.text_to_image(
    f"A masterpiece of a {placeholder_token} crying out to the heavens. "
    f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "
    "life right out of it.",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 314ms/step
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_43_1.png)
generated = stable_diffusion.text_to_image(
    f"An evil {placeholder_token}.", batch_size=3
)
plot_images(generated)
25/25 [==============================] - 8s 322ms/step
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_44_1.png)
generated = stable_diffusion.text_to_image(
    f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
    batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 315ms/step
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_45_1.png) --- ## 结论 使用文本反转算法,您可以教会稳定扩散新的概念! 一些可能的后续步骤: - 尝试您自己的提示 - 教给模型一种风格 - 收集您最喜欢的宠物猫或狗的数据集,并教给模型有关它的信息