代码示例 / 计算机视觉 / 调查视觉变换器表示

调查视觉变换器表示

作者: Aritra Roy Gosthipaty, Sayak Paul (共同贡献)
创建日期: 2022/04/12
最后修改日期: 2023/11/20
描述: 研究不同视觉变换器变体学习到的表示。

在 Colab 中查看 GitHub 源代码


介绍

在这个例子中,我们研究不同视觉变换器(ViT)模型学习到的表示。我们这个例子的主要目标是提供关于什么使得 ViT 能够从图像数据中学习的见解。特别地,本示例讨论了几个不同的 ViT 分析工具的实现。

注意: 当我们提到“视觉变换器”时,我们是指涉及变换器块的计算机视觉架构(Vaswani 等人),而不一定是最初的视觉变换器模型(Dosovitskiy 等人)。


考虑的模型

自从最初的视觉变换器问世以来,计算机视觉界已经见证了多个不同的 ViT 变体在各个方面对原始模型的改进:训练改进、架构改进等等。在这个例子中,我们考虑以下 ViT 模型系列:

  • 使用 ImageNet-1k 和 ImageNet-21k 数据集进行监督预训练的 ViT(Dosovitskiy 等人
  • 仅使用 ImageNet-1k 数据集进行监督预训练,但采用更多的正则化和蒸馏的 ViT(Touvron 等人)(DeiT)。
  • 使用自监督预训练的 ViT(Caron 等人)(DINO)。

由于这些预训练模型在 Keras 中未实现,我们首先尽可能忠实地实现了它们。然后,我们用官方的预训练参数填充它们。最后,我们在 ImageNet-1k 验证集上评估了我们的实现,以确保评估结果与原始实现相匹配。我们实现的细节可在 这个仓库 中找到。

为了保持示例简洁,我们不会详细配对每个模型和分析方法。我们将在各自的部分提供说明,以便您可以捡起这些碎片。

要在 Google Colab 上运行此示例,我们需要更新 gdown 库,如下所示:

pip install -U gdown -q

导入

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import zipfile
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests

from PIL import Image
from sklearn.preprocessing import MinMaxScaler
import keras
from keras import ops

常量

RESOLUTION = 224
PATCH_SIZE = 16
GITHUB_RELEASE = "https://github.com/sayakpaul/probing-vits/releases/download/v1.0.0/probing_vits.zip"
FNAME = "probing_vits.zip"
MODELS_ZIP = {
    "vit_dino_base16": "Probing_ViTs/vit_dino_base16.zip",
    "vit_b16_patch16_224": "Probing_ViTs/vit_b16_patch16_224.zip",
    "vit_b16_patch16_224-i1k_pretrained": "Probing_ViTs/vit_b16_patch16_224-i1k_pretrained.zip",
}

数据工具

对于原始的 ViT 模型,输入图像需要缩放到范围[-1, 1]。对于前面提到的其他模型系列,我们需要使用 ImageNet-1k 训练集的通道均值和标准差对图像进行归一化。

crop_layer = keras.layers.CenterCrop(RESOLUTION, RESOLUTION)
norm_layer = keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)
rescale_layer = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def preprocess_image(image, model_type, size=RESOLUTION):
    # 将图像转换为numpy数组并增加批次维度。
    image = np.array(image)
    image = ops.expand_dims(image, 0)

    # 如果模型类型是vit,将图像重新缩放到[-1, 1]。
    if model_type == "original_vit":
        image = rescale_layer(image)

    # 使用双线性插值调整图像大小。
    resize_size = int((256 / 224) * size)
    image = ops.image.resize(image, (resize_size, resize_size), interpolation="bicubic")

    # 裁剪图像。
    image = crop_layer(image)

    # 如果模型类型是DeiT或DINO,规范化图像。
    if model_type != "original_vit":
        image = norm_layer(image)

    return ops.convert_to_numpy(image)


def load_image_from_url(url, model_type):
    # 归功于:Willi Gierke
    response = requests.get(url)
    image = Image.open(BytesIO(response.content))
    preprocessed_image = preprocess_image(image, model_type)
    return image, preprocessed_image

加载测试图像并显示

# ImageNet-1k 标签映射文件并加载它。

mapping_file = keras.utils.get_file(
    origin="https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)

with open(mapping_file, "r") as f:
    lines = f.readlines()
imagenet_int_to_str = [line.rstrip() for line in lines]

img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="original_vit")

plt.imshow(image)
plt.axis("off")
plt.show()

png


加载模型

zip_path = keras.utils.get_file(
    fname=FNAME,
    origin=GITHUB_RELEASE,
)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall("./")

os.rename("Probing ViTs", "Probing_ViTs")


def load_model(model_path: str) -> keras.Model:
    with zipfile.ZipFile(model_path, "r") as zip_ref:
        zip_ref.extractall("Probing_ViTs/")
    model_name = model_path.split(".")[0]

    inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
    model = keras.layers.TFSMLayer(model_name, call_endpoint="serving_default")
    outputs = model(inputs, training=False)

    return keras.Model(inputs, outputs=outputs)


vit_base_i21k_patch16_224 = load_model(MODELS_ZIP["vit_b16_patch16_224-i1k_pretrained"])
print("模型已加载。")
模型已加载。

关于模型的更多信息

此模型在 ImageNet-21k 数据集上进行了预训练,然后在 ImageNet-1k 数据集上进行微调。要了解我们如何在 TensorFlow 中开发此模型(带有来自 该来源 的预训练权重),请参考 此笔记本


使用模型进行常规推断

我们现在在测试图像上使用加载的模型进行推断。

def split_prediction_and_attention_scores(outputs):
    predictions = outputs["output_1"]
    attention_score_dict = {}
    for key, value in outputs.items():
        if key.startswith("output_2_"):
            attention_score_dict[key[len("output_2_") :]] = value
    return predictions, attention_score_dict


predictions, attention_score_dict = split_prediction_and_attention_scores(
    vit_base_i21k_patch16_224.predict(preprocessed_image)
)
predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]
print(predicted_label)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step
金刚鹦鹉

WARNING: 所有日志消息在 absl::InitializeLog() 被调用之前都会写入 STDERR
I0000 00:00:1700526824.965785   75784 device_compiler.h:187] 使用 XLA 编译群集!此行最多在进程的生命周期内记录一次。

attention_score_dict 包含来自每个 Transformer 块的每个注意力头的注意力分数(经过 softmax 的输出)。


方法 I:均值注意力距离

Dosovitskiy 等Raghu 等 使用一种称为“均值注意力距离”的度量,从不同 Transformer 块的每个注意力头中了解局部和全局信息如何流入 Vision Transformers。

均值注意力距离被定义为查询词元与其他词元之间的距离乘以注意力权重。因此,对于单个图像

  • 我们提取从图像中提取的单个补丁(词元),
  • 计算它们的几何距离,以及
  • 用注意力得分相乘。

注意力得分是在推断模式下通过网络正向传递图像后计算的。以下图可能会帮助您更好地理解这个过程。

此动画由 Ritwik Raha 创建。

def compute_distance_matrix(patch_size, num_patches, length):
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # 零距离
                continue

            xi, yi = (int(i / length)), (i % length)
            xj, yj = (int(j / length)), (j % length)
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])

    return distance_matrix


def compute_mean_attention_dist(patch_size, attention_weights, model_type):
    num_cls_tokens = 2 if "distilled" in model_type else 1

    # attention_weights 的形状 = (batch, num_heads, num_patches, num_patches)
    attention_weights = attention_weights[
        ..., num_cls_tokens:, num_cls_tokens:
    ]  # 去掉 CLS token
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches 不是完全平方数"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)
    h, w = distance_matrix.shape

    distance_matrix = distance_matrix.reshape((1, 1, h, w))
    # attention_weights 在最后一个轴上的总和为1
    # 这是因为它们是原始逻辑值的 softmax
    # (attention_weights * distance_matrix) 的求和
    # 应该得到每个 token 的平均距离。
    mean_distances = attention_weights * distance_matrix
    mean_distances = np.sum(
        mean_distances, axis=-1
    )  # 在最后一个轴上求和以获得每个 token 的平均距离
    mean_distances = np.mean(
        mean_distances, axis=-1
    )  # 现在在所有的 tokens 之间取平均

    return mean_distances

感谢来自 Google 的 Simon Kornblith 给予我们这段代码的帮助。代码可以在 此处 找到。现在让我们使用这些工具生成我们的加载模型和测试图像的注意力距离图。

# 为每个 Transformer 块构建平均距离。
mean_distances = {
    f"{name}_mean_dist": compute_mean_attention_dist(
        patch_size=PATCH_SIZE,
        attention_weights=attention_weight,
        model_type="original_vit",
    )
    for name, attention_weight in attention_score_dict.items()
}

# 从平均距离输出中获取头的数量。
num_heads = mean_distances["transformer_block_0_att_mean_dist"].shape[-1]

# 打印形状
print(f"Num Heads: {num_heads}.")

plt.figure(figsize=(9, 9))

for idx in range(len(mean_distances)):
    mean_distance = mean_distances[f"transformer_block_{idx}_att_mean_dist"]
    x = [idx] * num_heads
    y = mean_distance[0, :]
    plt.scatter(x=x, y=y, label=f"transformer_block_{idx}")

plt.legend(loc="lower right")
plt.xlabel("Attention Head", fontsize=14)
plt.ylabel("Attention Distance", fontsize=14)
plt.title("vit_base_i21k_patch16_224", fontsize=14)
plt.grid()
plt.show()
Num Heads: 12.

png

检查图表

自注意力如何跨越输入空间?它们是局部关注输入区域还是全局?

自注意力的承诺是能够学习上下文依赖关系,使模型能够关注与目标最相关的输入区域。从以上图表中,我们可以注意到不同的注意力头产生不同的注意力距离,表明它们利用图像的局部和全球信息。但随着我们在 Transformer 块中深入,头部往往更关注全局聚合信息。

受到 Raghu et al. 的启发,我们在从 ImageNet-1k 验证集随机选取的 1000 张图像上计算了平均注意力距离,并为开头提到的所有模型重复了这个过程。有趣的是,我们注意到以下情况:

  • 用更大数据集进行预训练有助于获得更多的全局注意力跨度:
在 ImageNet-21k 上预训练
在 ImageNet-1k 上微调
在 ImageNet-1k 上预训练
  • 从 CNN 中蒸馏的 ViT 通常具有更少的全局注意力跨度:
没有蒸馏(来自 DeiT 的 ViT B-16) 从 DeiT 蒸馏的 ViT B-16

要复制这些图表,请参考 这个笔记本


方法 II:注意力展开

Abnar et al. 引入了“注意力展开”来量化信息如何通过 Transformer 块的自注意力层流动。原始 ViT 的作者使用此方法调查学习到的表示,陈述:

简而言之,我们对 ViTL/16 的注意力权重在所有头部上进行了平均,然后递归地乘以所有层的权重矩阵。这考虑了通过所有层跨标记的注意力混合。

我们使用 这个笔记本 并修改了其中的注意力展开代码以兼容我们的模型。

def attention_rollout_map(image, attention_score_dict, model_type):
    num_cls_tokens = 2 if "distilled" in model_type else 1

    # 堆叠来自各个 Transformer 块的单个注意力矩阵。
    attn_mat = ops.stack([attention_score_dict[k] for k in attention_score_dict.keys()])
    attn_mat = ops.squeeze(attn_mat, axis=1)

    # 在所有头部之间平均注意力权重。
    attn_mat = ops.mean(attn_mat, axis=1)

    # 考虑残差连接,我们在注意力矩阵中添加一个单位矩阵并重新归一化权重。
    residual_attn = ops.eye(attn_mat.shape[1])
    aug_attn_mat = attn_mat + residual_attn
    aug_attn_mat = aug_attn_mat / ops.sum(aug_attn_mat, axis=-1)[..., None]
    aug_attn_mat = ops.convert_to_numpy(aug_attn_mat)

    # 递归地乘以权重矩阵。
    joint_attentions = np.zeros(aug_attn_mat.shape)
    joint_attentions[0] = aug_attn_mat[0]

    for n in range(1, aug_attn_mat.shape[0]):
        joint_attentions[n] = np.matmul(aug_attn_mat[n], joint_attentions[n - 1])

    # 从输出 token 到输入空间的注意力。
    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_attn_mat.shape[-1]))
    mask = v[0, num_cls_tokens:].reshape(grid_size, grid_size)
    mask = cv2.resize(mask / mask.max(), image.size)[..., np.newaxis]
    result = (mask * image).astype("uint8")
    return result

现在让我们使用这些工具生成一个基于我们在“使用模型进行常规推理”部分中的结果的注意力图。以下是下载每个单独模型的链接:

attn_rollout_result = attention_rollout_map(
    image, attention_score_dict, model_type="original_vit"
)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 10))
fig.suptitle(f"Predicted label: {predicted_label}.", fontsize=20)

_ = ax1.imshow(image)
_ = ax2.imshow(attn_rollout_result)
ax1.set_title("输入图像", fontsize=16)
ax2.set_title("注意力图", fontsize=16)
ax1.axis("off")
ax2.axis("off")

fig.tight_layout()
fig.subplots_adjust(top=1.35)
fig.show()

png

检查图表

我们如何量化通过注意力层传播的信息流?

我们注意到模型能够将注意力集中在输入图像的显著部分。我们鼓励您将这种方法应用于我们提到的其他模型并比较结果。注意力传播图将根据模型训练时使用的任务和增强而有所不同。我们观察到DeiT具有最佳的传播图,这可能是由于其增强机制。


方法 III:注意力热图

探测 Vision Transformer 表示的一种简单而有效的方法是将注意力图可视化在输入图像上。这有助于形成模型关注点的直觉。我们使用 DINO 模型来实现此目的,因为它产生了更好的注意力热图。

# 加载模型。
vit_dino_base16 = load_model(MODELS_ZIP["vit_dino_base16"])
print("模型已加载。")

# 对同一图像进行处理,但进行标准化。
img_url = "https://dl.fbaipublicfiles.com/dino/img.png"
image, preprocessed_image = load_image_from_url(img_url, model_type="dino")

# 获取预测结果。
predictions, attention_score_dict = split_prediction_and_attention_scores(
    vit_dino_base16.predict(preprocessed_image)
)
模型已加载。
 1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step

一个 Transformer 块由多个头组成。Transformer 块中的每个头将输入数据投影到不同的子空间。这有助于每个单独的头关注图像的不同部分。因此,分别可视化每个注意力头图是有意义的,以了解每个头关注的内容。

注释

  • 以下代码已从原始 DINO 代码库进行复制修改。
  • 在这里,我们获取最后一个 Transformer 块的注意力图。
  • DINO 使用自监督目标进行预训练。
def attention_heatmap(attention_score_dict, image, model_type="dino"):
    num_tokens = 2 if "distilled" in model_type else 1

    # 按深度顺序排序 Transformer 块。
    attention_score_list = list(attention_score_dict.keys())
    attention_score_list.sort(key=lambda x: int(x.split("_")[-2]), reverse=True)

    # 处理注意力图以进行叠加。
    w_featmap = image.shape[2] // PATCH_SIZE
    h_featmap = image.shape[1] // PATCH_SIZE
    attention_scores = attention_score_dict[attention_score_list[0]]

    # 取 CLS 令牌的表示。
    attentions = attention_scores[0, :, 0, num_tokens:].reshape(num_heads, -1)

    # 重塑注意力评分以类似迷你补丁的形式。
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)
    attentions = attentions.transpose((1, 2, 0))

    # 将注意力补丁的大小调整为 224x224(224: 14x16)。
    attentions = ops.image.resize(
        attentions, size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE)
    )
    return attentions

我们可以使用与 DINO 推理相同的图像和从结果中提取的 attention_score_dict

# De-normalize the image for visual clarity.
in1k_mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
in1k_std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])
preprocessed_img_orig = (preprocessed_image * in1k_std) + in1k_mean
preprocessed_img_orig = preprocessed_img_orig / 255.0
preprocessed_img_orig = ops.convert_to_numpy(ops.clip(preprocessed_img_orig, 0.0, 1.0))

# Generate the attention heatmaps.
attentions = attention_heatmap(attention_score_dict, preprocessed_img_orig)

# Plot the maps.
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(3):
    for j in range(4):
        if img_count < len(attentions):
            axes[i, j].imshow(preprocessed_img_orig[0])
            axes[i, j].imshow(attentions[..., img_count], cmap="inferno", alpha=0.6)
            axes[i, j].title.set_text(f"Attention head: {img_count}")
            axes[i, j].axis("off")
            img_count += 1

png

检查图表

我们如何定性地评估注意力权重?

Transformer块的注意力权重是在键(key)和查询(query)之间计算的。权重量化了键对查询的重要性。在ViTs中,键和查询来自同一张图像,因此权重决定了图像的哪个部分是重要的。

将注意力权重叠加在图像上可以很好地直观显示出Transformer认为重要的图像部分。这个图表定性地评估了注意力权重的目的。


方法IV:可视化学习到的投影过滤器

在提取非重叠的补丁后,ViTs在其空间维度上将这些补丁展平,然后进行线性投影。人们会想,这些投影看起来如何?以下,我们取ViT B-16模型并可视化其学习到的投影。

def extract_weights(model, name):
    for variable in model.weights:
        if variable.name.startswith(name):
            return variable.numpy()


# 提取投影。
projections = extract_weights(vit_base_i21k_patch16_224, "conv_projection/kernel")
projection_dim = projections.shape[-1]
patch_h, patch_w, patch_channels = projections.shape[:-1]

# 缩放投影。
scaled_projections = MinMaxScaler().fit_transform(
    projections.reshape(-1, projection_dim)
)

# 将缩放后的投影重塑,使前
# 三个维度类似于图像。
scaled_projections = scaled_projections.reshape(patch_h, patch_w, patch_channels, -1)

# 可视化学习到的前128个
# 投影过滤器。
fig, axes = plt.subplots(nrows=8, ncols=16, figsize=(13, 8))
img_count = 0
limit = 128

for i in range(8):
    for j in range(16):
        if img_count < limit:
            axes[i, j].imshow(scaled_projections[..., img_count])
            axes[i, j].axis("off")
            img_count += 1

fig.tight_layout()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats或[0..255]为整数)。

png

检查图表

投影过滤器学到了什么?

当可视化时,卷积神经网络的核显示它们在图像中寻找的模式。这可能是圆形,有时是线条——当在ConvNet的后期组合在一起时,过滤器转变为更复杂的形状。我们发现这种ConvNet核与ViT的投影过滤器之间有明显的相似性。


方法V:可视化位置嵌入

Transformer是排列不变的。这意味着它们不考虑输入标记的空间位置。为了解决这一限制,我们将位置信息添加到输入标记中。

位置信息可以采取学习到的位置嵌入或手工制作的常量嵌入形式。在我们的案例中,所有三种变体的ViTs都具有学习到的位置嵌入。

在这一部分,我们可视化学习到的位置嵌入之间的相似性。以下,我们取ViT B-16模型,通过取它们的点积可视化位置嵌入的相似性。

position_embeddings = extract_weights(vit_base_i21k_patch16_224, "pos_embedding")

# 丢弃批处理维度和
# cls标记的位置嵌入。
position_embeddings = position_embeddings.squeeze()[1:, ...]

similarity = position_embeddings @ position_embeddings.T
plt.imshow(similarity, cmap="inferno")
plt.show()

png

检查图表

位置嵌入告诉我们什么?

该图表具有明显的对角模式。主对角线是最亮的,表示一个位置与它自己最相似。一个有趣的模式是重复的对角线。重复的模式描绘了一个正弦函数,这与此前提出的内容本质上相近。 Vaswani et. al. 作为手工特征。


注释

  • DINO 扩展了注意力热图生成过程到视频。我们也 应用了我们的 DINO 实现于一系列视频,并获得了类似的结果。这是一个显示注意力热图的视频:

    dino

  • Raghu et al. 使用一系列技术来 调查 ViT 学习的表示,并与 ResNet 的结果进行比较。我们强烈推荐阅读他们的工作。
  • 为了编写这个示例,我们开发了 这个仓库 来指导我们的读者,以便他们 可以轻松地复制实验并扩展它们。
  • 另一个您可能会对此感兴趣的仓库是 vit-explain
  • 还可以使用我们的 Hugging Face 空间绘制自定义图像的注意力传播和注意力热图。
注意力热图 注意力传播
Generic badge Generic badge

致谢