作者: Ian Stenbit, lukewood
创建日期: 2022/12/09
最后修改: 2022/12/09
描述: 使用KerasCV的StableDiffusion实现学习新的视觉概念。
自发布以来,StableDiffusion迅速成为生成机器学习社区的宠儿。 大量的流量导致了开源贡献的改进、繁重的提示工程,甚至新算法的发明。
也许最令人印象深刻的新算法是 文本反转,该算法在 一张图胜过千言万语:使用文本反转个性化文本到图像生成中提出。
文本反转是通过微调教导图像生成器特定视觉概念的过程。在下图中,您可以看到这个过程的示例,作者教模型新的概念,将其称为 "S_*"。
从概念上讲,文本反转通过学习一个新文本令牌的令牌嵌入来工作,同时保持StableDiffusion的其余组件不变。
本指南将向您展示如何使用文本反转算法微调KerasCV中提供的StableDiffusion模型。到本指南结束时,您将能够写出“作为<my-funny-cat-token>的灰袍甘道夫”。
首先,让我们导入所需的包,并创建一个StableDiffusion实例,以便我们可以使用其一些子组件进行微调。
!pip install -q git+https://github.com/keras-team/keras-cv.git
!pip install -q tensorflow==2.11.0
import math
import keras_cv
import numpy as np
import tensorflow as tf
from keras_cv import layers as cv_layers
from keras_cv.models.stable_diffusion import NoiseScheduler
from tensorflow import keras
import matplotlib.pyplot as plt
stable_diffusion = keras_cv.models.StableDiffusion()
使用此模型检查点即表示您承认其使用受CreativeML Open RAIL-M许可证的条款限制,链接:https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
接下来,让我们定义一个可视化工具,以展示生成的图像:
def plot_images(images):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
plt.imshow(images[i])
plt.axis("off")
为了训练新标记的嵌入,我们首先必须组装一个包含文本-图像对的数据集。数据集的每个样本必须包含一个我们正在教给StableDiffusion的概念图像,以及一个准确表示图像内容的标题。在本教程中,我们将教给StableDiffusion Luke和Ian的GitHub头像的概念:
首先,让我们构建一个猫玩偶的图像数据集:
def assemble_image_dataset(urls):
# 获取所有远程文件
files = [tf.keras.utils.get_file(origin=url) for url in urls]
# 缩放图像
resize = keras.layers.Resizing(height=512, width=512, crop_to_aspect_ratio=True)
images = [keras.utils.load_img(img) for img in files]
images = [keras.utils.img_to_array(img) for img in images]
images = np.array([resize(img) for img in images])
# StableDiffusion图像编码器要求图像归一化到[-1, 1]的像素值范围
images = images / 127.5 - 1
# 创建tf.data.Dataset
image_dataset = tf.data.Dataset.from_tensor_slices(images)
# 随机打乱并引入随机噪声
image_dataset = image_dataset.shuffle(50, reshuffle_each_iteration=True)
image_dataset = image_dataset.map(
cv_layers.RandomCropAndResize(
target_size=(512, 512),
crop_area_factor=(0.8, 1.0),
aspect_ratio_factor=(1.0, 1.0),
),
num_parallel_calls=tf.data.AUTOTUNE,
)
image_dataset = image_dataset.map(
cv_layers.RandomFlip(mode="horizontal"),
num_parallel_calls=tf.data.AUTOTUNE,
)
return image_dataset
接下来,我们组装一个文本数据集:
MAX_PROMPT_LENGTH = 77
placeholder_token = "<my-funny-cat-token>"
def pad_embedding(embedding):
return embedding + (
[stable_diffusion.tokenizer.end_of_text] * (MAX_PROMPT_LENGTH - len(embedding))
)
stable_diffusion.tokenizer.add_tokens(placeholder_token)
def assemble_text_dataset(prompts):
prompts = [prompt.format(placeholder_token) for prompt in prompts]
embeddings = [stable_diffusion.tokenizer.encode(prompt) for prompt in prompts]
embeddings = [np.array(pad_embedding(embedding)) for embedding in embeddings]
text_dataset = tf.data.Dataset.from_tensor_slices(embeddings)
text_dataset = text_dataset.shuffle(100, reshuffle_each_iteration=True)
return text_dataset
最后,我们将数据集组合在一起以生成一个文本-图像对数据集。
def assemble_dataset(urls, prompts):
image_dataset = assemble_image_dataset(urls)
text_dataset = assemble_text_dataset(prompts)
# 图像数据集相对较短,因此我们重复它以匹配文本提示数据集的长度
image_dataset = image_dataset.repeat()
# 我们使用文本提示数据集来确定数据集的长度。由于提示相对较少,我们将数据集重复5次。
# 我们发现这在经验上提高了结果。
text_dataset = text_dataset.repeat(5)
return tf.data.Dataset.zip((image_dataset, text_dataset))
为了确保我们的提示是描述性的,我们使用非常通用的提示。
让我们试试这个样本图像和提示。
train_ds = assemble_dataset(
urls=[
"https://i.imgur.com/VIedH1X.jpg",
"https://i.imgur.com/eBw13hE.png",
"https://i.imgur.com/oJ3rSg7.png",
"https://i.imgur.com/5mCL6Df.jpg",
"https://i.imgur.com/4Q6WWyI.jpg",
],
prompts=[
"一张 {} 的照片",
"一幅 {} 的渲染画",
"一张 {} 的裁剪照片",
"一张 {} 的照片",
"一张干净的 {} 的照片",
"一张黑暗的 {} 的照片",
"一张我 {} 的照片",
"一张酷炫的 {} 的照片",
"一张 {} 的特写照片",
"一张明亮的 {} 的照片",
"一张 {} 的裁剪照片",
"一张 {} 的照片",
"一张好的 {} 的照片",
"一张一个 {} 的照片",
"一张 {} 的特写照片",
"一幅 {} 的再现画",
"一张干净的 {} 的照片",
"一幅 {} 的再现画",
"一张好看的 {} 的照片",
"一张好的 {} 的照片",
"一张好的 {} 的照片",
"一张小的 {} 的照片",
"一张奇怪的 {} 的照片",
"一张大的 {} 的照片",
"一张酷炫的 {} 的照片",
"一张小的 {} 的照片",
],
)
在我们第一次尝试撰写本指南时,我们在数据集中包含了这些猫娃娃的群体图像,但继续使用上面列出的通用提示。 我们的结果在经验上很差。例如,这里是使用这种方法的猫娃娃甘道夫:
它在概念上接近,但没有达到最佳效果。
为了弥补这一点,我们开始尝试将图像分割为单个猫娃娃的图像和猫娃娃的群体图像。 在这次分割后,我们为群体照想出了新的提示。
在准确表示内容的文本-图像对上进行训练显著提高了我们结果的质量。这说明了提示准确性的重要性。
除了将图像分为单个图像和群体图像外,我们还删除了一些不准确的提示,例如“黑暗的 {} 的照片”。
考虑到这一点,我们在下面组装了最终的训练数据集:
single_ds = assemble_dataset(
urls=[
"https://i.imgur.com/VIedH1X.jpg",
"https://i.imgur.com/eBw13hE.png",
"https://i.imgur.com/oJ3rSg7.png",
"https://i.imgur.com/5mCL6Df.jpg",
"https://i.imgur.com/4Q6WWyI.jpg",
],
prompts=[
"一张 {} 的照片",
"一幅 {} 的渲染画",
"一张 {} 的裁剪照片",
"一张 {} 的照片",
"一张干净的 {} 的照片",
"一张我 {} 的照片",
"一张酷炫的 {} 的照片",
"一张 {} 的特写照片",
"一张明亮的 {} 的照片",
"一张 {} 的裁剪照片",
"一张 {} 的照片",
"一张好的 {} 的照片",
"一张一个 {} 的照片",
"一张 {} 的特写照片",
"一幅 {} 的再现画",
"一张干净的 {} 的照片",
"一幅 {} 的再现画",
"一张好看的 {} 的照片",
"一张好的 {} 的照片",
"一张好的 {} 的照片",
"一张小的 {} 的照片",
"一张奇怪的 {} 的照片",
"一张大的 {} 的照片",
"一张酷炫的 {} 的照片",
"一张小的 {} 的照片",
],
)
看起来不错!
接下来,我们组装一个我们 GitHub 头像的群体数据集:
group_ds = assemble_dataset(
urls=[
"https://i.imgur.com/yVmZ2Qa.jpg",
"https://i.imgur.com/JbyFbZJ.jpg",
"https://i.imgur.com/CCubd3q.jpg",
],
prompts=[
"一张 {} 的团队照片",
"一张 {} 的团队渲染图",
"一张裁剪过的 {} 团队照片",
"一张 {} 的团队照片",
"一张干净的 {} 团队照片",
"我的 {} 团队照片",
"一张很酷的 {} 团队照片",
"一张 {} 团队的特写照片",
"一张明亮的 {} 团队照片",
"一张裁剪过的 {} 团队照片",
"一张 {} 的团队照片",
"一张好的 {} 团队照片",
"一张一组 {} 的照片",
"一张 {} 团队的特写照片",
"一张 {} 团队的渲染",
"一张干净的 {} 团队照片",
"一张 {} 团队的渲染",
"一张漂亮的 {} 团队照片",
"一张好的 {} 团队照片",
"一张好的 {} 团队照片",
"一张小型 {} 团队照片",
"一张奇怪的 {} 团队照片",
"一张大型 {} 团队照片",
"一张很酷的 {} 团队照片",
"一张小型 {} 团队照片",
],
)
最后,我们将两个数据集连接在一起:
train_ds = single_ds.concatenate(group_ds)
train_ds = train_ds.batch(1).shuffle(
train_ds.cardinality(), reshuffle_each_iteration=True
)
接下来,我们为StableDiffusion模型创建一个新的文本编码器,并将我们的新嵌入添加到模型中,为'
tokenized_initializer = stable_diffusion.tokenizer.encode("cat")[1]
new_weights = stable_diffusion.text_encoder.layers[2].token_embedding(
tf.constant(tokenized_initializer)
)
# 获取.vocab的长度,而不是tokenizer
new_vocab_size = len(stable_diffusion.tokenizer.vocab)
# 嵌入层是文本编码器中的第二层
old_token_weights = stable_diffusion.text_encoder.layers[
2
].token_embedding.get_weights()
old_position_weights = stable_diffusion.text_encoder.layers[
2
].position_embedding.get_weights()
old_token_weights = old_token_weights[0]
new_weights = np.expand_dims(new_weights, axis=0)
new_weights = np.concatenate([old_token_weights, new_weights], axis=0)
让我们构建一个新的TextEncoder并准备它。
# 必须将download_weights设置为False,以便我们可以初始化(否则会尝试加载权重)
new_encoder = keras_cv.models.stable_diffusion.TextEncoder(
keras_cv.models.stable_diffusion.stable_diffusion.MAX_PROMPT_LENGTH,
vocab_size=new_vocab_size,
download_weights=False,
)
for index, layer in enumerate(stable_diffusion.text_encoder.layers):
# 第二层是嵌入层,因此我们在复制权重时省略它
if index == 2:
continue
new_encoder.layers[index].set_weights(layer.get_weights())
new_encoder.layers[2].token_embedding.set_weights([new_weights])
new_encoder.layers[2].position_embedding.set_weights(old_position_weights)
stable_diffusion._text_encoder = new_encoder
stable_diffusion._text_encoder.compile(jit_compile=True)
现在我们可以进入激动人心的部分:训练!
在TextualInversion中,唯一要训练的模型部分是嵌入向量。 让我们冻结模型的其余部分。
stable_diffusion.diffusion_model.trainable = False
stable_diffusion.decoder.trainable = False
stable_diffusion.text_encoder.trainable = True
stable_diffusion.text_encoder.layers[2].trainable = True
def traverse_layers(layer):
if hasattr(layer, "layers"):
for layer in layer.layers:
yield layer
if hasattr(layer, "token_embedding"):
yield layer.token_embedding
if hasattr(layer, "position_embedding"):
yield layer.position_embedding
for layer in traverse_layers(stable_diffusion.text_encoder):
if isinstance(layer, keras.layers.Embedding) or "clip_embedding" in layer.name:
layer.trainable = True
else:
layer.trainable = False
new_encoder.layers[2].position_embedding.trainable = False
让我们确认设置为可训练的权重是正确的。
all_models = [
stable_diffusion.text_encoder,
stable_diffusion.diffusion_model,
stable_diffusion.decoder,
]
print([[w.shape for w in model.trainable_weights] for model in all_models])
[[TensorShape([49409, 768])], [], []]
为了训练嵌入,我们需要几个工具。 我们从KerasCV导入一个NoiseScheduler,并在下面定义以下工具:
sample_from_encoder_outputs
是一个包装器,用于基础StableDiffusion图像编码器,它从图像编码器产生的统计分布中采样,而不是像许多其他SD应用一样仅取平均值。get_timestep_embedding
为扩散模型产生指定时间步的嵌入。get_position_ids
为文本编码器生成一个位置ID的张量(它只是从[1, MAX_PROMPT_LENGTH]
的一系列数字)。# 从编码器中移除顶部层,这将切断方差并仅返回
# 均值
training_image_encoder = keras.Model(
stable_diffusion.image_encoder.input,
stable_diffusion.image_encoder.layers[-2].output,
)
def sample_from_encoder_outputs(outputs):
mean, logvar = tf.split(outputs, 2, axis=-1)
logvar = tf.clip_by_value(logvar, -30.0, 20.0)
std = tf.exp(0.5 * logvar)
sample = tf.random.normal(tf.shape(mean))
return mean + std * sample
def get_timestep_embedding(timestep, dim=320, max_period=10000):
half = dim // 2
freqs = tf.math.exp(
-math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
)
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
return embedding
def get_position_ids():
return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
接下来,我们实现一个 StableDiffusionFineTuner
,它是 keras.Model
的子类,重载了 train_step
以训练我们文本编码器的令牌嵌入。这是文本反转算法的核心。
抽象地说,训练步骤从冻结的 SD 图像编码器的潜在分布中取样训练图像,给该样本添加噪声,然后将这个带噪声的样本传递给冻结的扩散模型。扩散模型的隐藏状态是对应于图像提示的文本编码器的输出。
我们最终的目标状态是,扩散模型能够使用文本编码作为隐藏状态从样本中分离噪声,因此我们的损失是噪声与扩散模型输出的均方误差(理想情况下,扩散模型已从噪声中去除了图像潜在值)。
我们仅计算文本编码器的令牌嵌入的梯度,在训练步骤中,我们将除了我们正在学习的令牌之外的所有令牌的梯度清零。
有关训练步骤的更多详细信息,请参见内联代码注释。
class StableDiffusionFineTuner(keras.Model):
def __init__(self, stable_diffusion, noise_scheduler, **kwargs):
super().__init__(**kwargs)
self.stable_diffusion = stable_diffusion
self.noise_scheduler = noise_scheduler
def train_step(self, data):
images, embeddings = data
with tf.GradientTape() as tape:
# 从训练图像的预测分布中取样
latents = sample_from_encoder_outputs(training_image_encoder(images))
# 潜在值必须下采样以匹配训练 StableDiffusion 时使用的潜在值的尺度。
# 这个数字确实只是他们在训练模型时选择的一个“魔术”常数。
latents = latents * 0.18215
# 产生与潜在样本形状相同的随机噪声
noise = tf.random.normal(tf.shape(latents))
batch_dim = tf.shape(latents)[0]
# 为批中的每个样本随机选择一个时间步
timesteps = tf.random.uniform(
(batch_dim,),
minval=0,
maxval=noise_scheduler.train_timesteps,
dtype=tf.int64,
)
# 根据每个样本的时间步将噪声添加到潜在值中
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
# 编码训练样本中的文本以用作扩散模型中的隐藏状态
encoder_hidden_state = self.stable_diffusion.text_encoder(
[embeddings, get_position_ids()]
)
# 为批中每个样本随机选择的时间步计算时间步嵌入
timestep_embeddings = tf.map_fn(
fn=get_timestep_embedding,
elems=timesteps,
fn_output_signature=tf.float32,
)
# 调用扩散模型
noise_pred = self.stable_diffusion.diffusion_model(
[noisy_latents, timestep_embeddings, encoder_hidden_state]
)
# 计算均方误差损失并进行减少。
loss = self.compiled_loss(noise_pred, noise)
loss = tf.reduce_mean(loss, axis=2)
loss = tf.reduce_mean(loss, axis=1)
loss = tf.reduce_mean(loss)
# 加载可训练权重并计算它们的梯度
trainable_weights = self.stable_diffusion.text_encoder.trainable_weights
grads = tape.gradient(loss, trainable_weights)
# 梯度存储在索引切片中,因此我们必须找到包含占位符令牌的切片的索引。
index_of_placeholder_token = tf.reshape(tf.where(grads[0].indices == 49408), ())
condition = grads[0].indices == 49408
condition = tf.expand_dims(condition, axis=-1)
# 重写梯度,将所有不是占位符令牌的切片的梯度清零,有效冻结其他令牌的权重。
grads[0] = tf.IndexedSlices(
values=tf.where(condition, grads[0].values, 0),
indices=grads[0].indices,
dense_shape=grads[0].dense_shape,
)
self.optimizer.apply_gradients(zip(grads, trainable_weights))
return {"loss": loss}
在我们开始训练之前,让我们看看 StableDiffusion 为我们的令牌生成了什么。
generated = stable_diffusion.text_to_image(
f"an oil painting of {placeholder_token}", seed=1337, batch_size=3
)
plot_images(generated)
25/25 [==============================] - 19s 314ms/step
如您所见,模型仍然将我们的令牌视为一只猫,因为这是我们用来初始化自定义令牌的种子令牌。
现在,为了开始训练,我们可以像其他 Keras 模型一样简单地 compile()
我们的模型。在此之前,我们还为训练实例化了一个噪声调度器,并配置了诸如学习率和优化器等训练参数。
noise_scheduler = NoiseScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
train_timesteps=1000,
)
trainer = StableDiffusionFineTuner(stable_diffusion, noise_scheduler, name="trainer")
EPOCHS = 50
learning_rate = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=1e-4, decay_steps=train_ds.cardinality() * EPOCHS
)
optimizer = keras.optimizers.Adam(
weight_decay=0.004, learning_rate=learning_rate, epsilon=1e-8, global_clipnorm=10
)
trainer.compile(
optimizer=optimizer,
# 我们在训练步骤中手动执行了缩减,因此这里不需要。
loss=keras.losses.MeanSquaredError(reduction="none"),
)
为了监控训练,我们可以生成一个 keras.callbacks.Callback
以使用我们的自定义令牌每个周期生成一些图像。
我们创建了三个具有不同提示的回调,以便我们可以看到它们在训练过程中的进展。我们使用固定种子,以便容易查看学习到的令牌的进展。
class GenerateImages(keras.callbacks.Callback):
def __init__(
self, stable_diffusion, prompt, steps=50, frequency=10, seed=None, **kwargs
):
super().__init__(**kwargs)
self.stable_diffusion = stable_diffusion
self.prompt = prompt
self.seed = seed
self.frequency = frequency
self.steps = steps
def on_epoch_end(self, epoch, logs):
if epoch % self.frequency == 0:
images = self.stable_diffusion.text_to_image(
self.prompt, batch_size=3, num_steps=self.steps, seed=self.seed
)
plot_images(
images,
)
cbs = [
GenerateImages(
stable_diffusion, prompt=f"一幅 {placeholder_token} 的油画", seed=1337
),
GenerateImages(
stable_diffusion, prompt=f"甘道夫灰衣作为一个 {placeholder_token}", seed=1337
),
GenerateImages(
stable_diffusion,
prompt=f"两个 {placeholder_token} 结婚,照片真实感,高质量",
seed=1337,
),
]
现在,剩下的就是调用 model.fit()
!
trainer.fit(
train_ds,
epochs=EPOCHS,
callbacks=cbs,
)
</div>
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_2.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_3.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_4.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_5.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_6.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_7.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_8.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_9.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_10.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_11.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_12.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_13.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_14.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_15.png)
![png](/img/examples/generative/fine_tune_via_textual_inversion/fine_tune_via_textual_inversion_39_16.png)
看到模型如何随着时间学习我们的新标记真的很有趣。尽情尝试,看看如何调整训练参数和您的训练数据集,以生成最佳图像!
---
## 测试微调模型
现在是非常有趣的部分。我们为我们的自定义标记学习了一个标记嵌入,因此现在我们可以像对待任何其他标记一样使用稳定扩散生成图像!
这里有一些有趣的示例提示来帮助您入门,以及我们猫玩具标记的示例输出!
```python
generated = stable_diffusion.text_to_image(
f"Gandalf as a {placeholder_token} fantasy art drawn by disney concept artists, "
"golden colour, high quality, highly detailed, elegant, sharp focus, concept art, "
"character concepts, digital painting, mystery, adventure",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 316ms/step
generated = stable_diffusion.text_to_image(
f"A masterpiece of a {placeholder_token} crying out to the heavens. "
f"Behind the {placeholder_token}, an dark, evil shade looms over it - sucking the "
"life right out of it.",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 314ms/step
generated = stable_diffusion.text_to_image(
f"An evil {placeholder_token}.", batch_size=3
)
plot_images(generated)
25/25 [==============================] - 8s 322ms/step
generated = stable_diffusion.text_to_image(
f"A mysterious {placeholder_token} approaches the great pyramids of egypt.",
batch_size=3,
)
plot_images(generated)
25/25 [==============================] - 8s 315ms/step