代码示例 / 计算机视觉 / 带有 LayerScale 的类注意力图像变换器

带有 LayerScale 的类注意力图像变换器

作者: Sayak Paul
创建日期: 2022/09/19
最后修改日期: 2022/11/21
描述: 实现带有类注意力和 LayerScale 的图像变换器。

在 Colab 中查看 GitHub 源代码


介绍

在本教程中,我们实现了 CaiT(类注意力图像变换器),该方法由 Touvron 等人提出于 Going deeper with Image Transformers。深度扩展,即通过增加模型深度以获得更好的性能和泛化能力,对于卷积神经网络(例如 Tan et al.Dollár et al.)相当成功。但将相同的模型扩展原则应用于视觉变换器(Dosovitskiy et al.)并不具有同样的效果——它们的性能在深度扩展时迅速饱和。请注意,这里的一个假设是,在进行模型扩展时,底层预训练数据集始终保持固定。

在 CaiT 论文中,作者探讨了这一现象,并对原始 ViT(视觉变换器)架构提出了修改以缓解此问题。

本教程的结构如下:

  • 实现 CaiT 的各个模块
  • 将所有模块汇总以创建 CaiT 模型
  • 加载预训练的 CaiT 模型
  • 获取预测结果
  • 可视化 CaiT 的不同注意力层

假设读者已经熟悉视觉变换器。以下是 Keras 中视觉变换器的实现:使用视觉变换器进行图像分类


导入

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import io
import typing
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy as np
import PIL
import keras
from keras import layers
from keras import ops

LayerScale 层

我们首先实现一个 LayerScale 层,这是 CaiT 论文中提出的两种修改之一。

当增加 ViT 模型的深度时,它们会遇到优化不稳定的问题,并最终无法收敛。每个变换器模块中的残差连接引入了信息瓶颈。当深度增加时,这个瓶颈可能会迅速爆炸,偏离底层模型的优化路径。

以下方程表示在变换器模块内添加残差连接的位置:

其中,SA 表示自注意力,FFN 表示前馈网络,eta 表示 LayerNorm 操作(Ba et al.)。

LayerScale 的正式实现如下:

其中,lambda 是可学习的参数,并且初始值非常小({0.1, 1e-5, 1e-6})。diag 表示对角矩阵。

直观上,LayerScale 帮助控制残差分支的贡献。LayerScale 的可学习参数初始化为小值,以使分支表现得像恒等函数,然后让它们在训练过程中找出交互的程度。对角矩阵还通过逐通道应用帮助控制残差输入的各个维度的贡献。

LayerScale 的实际实现比听上去要简单。

class LayerScale(layers.Layer):
    """在 CaiT 中引入的 LayerScale: https://arxiv.org/abs/2103.17239。

    参数:
        init_values (float): 初始化 LayerScale 的对角矩阵的值。
        projection_dim (int): LayerScale 中使用的投影维度。
    """

    def __init__(self, init_values: float, projection_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.gamma = self.add_weight(
            shape=(projection_dim,),
            initializer=keras.initializers.Constant(init_values),
        )

    def call(self, x, training=False):
        return x * self.gamma

随机深度层

自其引入以来(Huang et al.),随机深度已成为几乎所有现代神经网络架构的热门组件。 CaiT 也不例外。讨论随机深度超出本笔记本的范围。如果您需要复习,可以参考 这个资源

class StochasticDepth(layers.Layer):
    """随机深度层 (https://arxiv.org/abs/1603.09382)。

    参考文献:
        https://github.com/rwightman/pytorch-image-models
    """

    def __init__(self, drop_prob: float, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prob
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + ops.random.uniform(
                shape, minval=0, maxval=1, seed=self.seed_generator
            )
            random_tensor = ops.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x

类注意力

原始ViT使用自注意力(SA)层来建模图像补丁与可学习的CLS标记之间的交互。CaiT的作者提出解耦负责关注图像补丁和CLS标记的注意力层。

在任何区分性任务(例如分类)中使用ViTs时,我们通常会获取属于CLS标记的表示,然后将其传递给特定任务的头部。这与在卷积神经网络中通常使用的全局平均池化不同。

CLS标记与其他图像补丁之间的交互通过自注意力层均匀处理。正如CaiT的作者所指出的,这种设置具有纠缠效应。一方面,自注意力层负责建模图像补丁。另一方面,它们还负责通过CLS标记总结建模的信息,使其对学习目标有用。

为了解开这两者,作者提出:

  • 在网络的较晚阶段引入CLS标记。
  • 通过一组单独的注意力层建模CLS标记与与图像补丁相关的表示之间的交互。作者将此称为类注意力(CA)。

下面的图(取自原始论文)描绘了这个想法:

这是通过将CLS标记嵌入视为CA层中的查询来实现的。CLS标记嵌入和图像补丁嵌入作为键以及值输入。

注意 "嵌入" 和 "表示" 在这里可以互换使用。

class ClassAttention(layers.Layer):
    """类注意力,参考文献 CaiT: https://arxiv.org/abs/2103.17239。

    参数:
        projection_dim (int): 注意力的查询、键和值的投影维度。
        num_heads (int): 注意力头的数量。
        dropout_rate (float): 用于注意力得分和最终投影输出的 dropout 率。
    """

    def __init__(
        self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads

        head_dim = projection_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = layers.Dense(projection_dim)
        self.k = layers.Dense(projection_dim)
        self.v = layers.Dense(projection_dim)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(projection_dim)
        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        batch_size, num_patches, num_channels = (
            ops.shape(x)[0],
            ops.shape(x)[1],
            ops.shape(x)[2],
        )

        # 查询投影。`cls_token` 嵌入作为查询。
        q = ops.expand_dims(self.q(x[:, 0]), axis=1)
        q = ops.reshape(
            q, (batch_size, 1, self.num_heads, num_channels // self.num_heads)
        )  # 形状: (batch_size, 1, num_heads, 每个头的维度)
        q = ops.transpose(q, axes=[0, 2, 1, 3])
        scale = ops.cast(self.scale, dtype=q.dtype)
        q = q * scale

        # 键投影。补丁嵌入和 cls 嵌入用作键。
        k = self.k(x)
        k = ops.reshape(
            k, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
        )  # 形状: (batch_size, num_tokens, num_heads, 每个头的维度)
        k = ops.transpose(k, axes=[0, 2, 3, 1])

        # 值投影。补丁嵌入和 cls 嵌入用作值。
        v = self.v(x)
        v = ops.reshape(
            v, (batch_size, num_patches, self.num_heads, num_channels // self.num_heads)
        )
        v = ops.transpose(v, axes=[0, 2, 1, 3])

        # 计算 cls_token 嵌入与补丁嵌入之间的注意力得分。
        attn = ops.matmul(q, k)
        attn = ops.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        x_cls = ops.matmul(attn, v)
        x_cls = ops.transpose(x_cls, axes=[0, 2, 1, 3])
        x_cls = ops.reshape(x_cls, (batch_size, 1, num_channels))
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls, training=training)

        return x_cls, attn

对话头注意力

CaiT 的作者使用了对话头注意力 (Shazeer et al.) 而不是原始 Transformer 论文中使用的普通缩放点积多头注意力 (Vaswani et al.)。 他们在 softmax 操作之前和之后引入了两个线性投影,以获得更好的结果。

有关对话头注意力和普通注意力机制的更严格处理,请参考各自的论文(链接在上面)。

class TalkingHeadAttention(layers.Layer):
    """按照 CaiT 提出的对话头注意力: https://arxiv.org/abs/2003.02436。

    Args:
        projection_dim (int): 注意力的查询、键和值的投影维度。
        num_heads (int): 注意力头的数量。
        dropout_rate (float): 用于注意力得分和最终投影输出的 dropout 率。
    """

    def __init__(
        self, projection_dim: int, num_heads: int, dropout_rate: float, **kwargs
    ):
        super().__init__(**kwargs)

        self.num_heads = num_heads

        head_dim = projection_dim // self.num_heads

        self.scale = head_dim**-0.5

        self.qkv = layers.Dense(projection_dim * 3)
        self.attn_drop = layers.Dropout(dropout_rate)

        self.proj = layers.Dense(projection_dim)

        self.proj_l = layers.Dense(self.num_heads)
        self.proj_w = layers.Dense(self.num_heads)

        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        B, N, C = ops.shape(x)[0], ops.shape(x)[1], ops.shape(x)[2]

        # 一次性对输入进行投影。
        qkv = self.qkv(x)

        # 重新塑形投影输出,以便根据查询、键和值投影进行分隔。
        qkv = ops.reshape(qkv, (B, N, 3, self.num_heads, C // self.num_heads))

        # 转置以使 `num_heads` 成为领先维度。
        # 有助于更好地分隔表示子空间。
        qkv = ops.transpose(qkv, axes=[2, 0, 3, 1, 4])
        scale = ops.cast(self.scale, dtype=qkv.dtype)
        q, k, v = qkv[0] * scale, qkv[1], qkv[2]

        # 获取原始注意力得分。
        attn = ops.matmul(q, ops.transpose(k, axes=[0, 1, 3, 2]))

        # 查询和键投影之间相似性线性投影。
        attn = self.proj_l(ops.transpose(attn, axes=[0, 2, 3, 1]))

        # 归一化注意力得分。
        attn = ops.transpose(attn, axes=[0, 3, 1, 2])
        attn = ops.nn.softmax(attn, axis=-1)

        # 对 softmax 后的得分进行线性投影。
        attn = self.proj_w(ops.transpose(attn, axes=[0, 2, 3, 1]))
        attn = ops.transpose(attn, axes=[0, 3, 1, 2])
        attn = self.attn_drop(attn, training=training)

        # 按照普通注意力机制进行最终投影。
        x = ops.matmul(attn, v)
        x = ops.transpose(x, axes=[0, 2, 1, 3])
        x = ops.reshape(x, (B, N, C))

        x = self.proj(x)
        x = self.proj_drop(x, training=training)

        return x, attn

前馈网络

接下来,我们实现前馈网络,这是 Transformer 块中的一个组件。

def mlp(x, dropout_rate: float, hidden_units: typing.List[int]):
    """Transformer 块的前馈网络(FFN)。"""
    for idx, units in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=ops.nn.gelu if idx == 0 else None,
            bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

其他块

在接下来的两个单元中,我们将其余块实现为独立的函数:

  • LayerScaleBlockClassAttention() 返回一个 keras.Model。它是一个带有类注意力、层缩放和随机深度的 Transformer 块。它作用于 CLS 嵌入和图像补丁嵌入。
  • LayerScaleBlock() 返回一个 keras.model。它也是一个仅对图像补丁嵌入进行操作的 Transformer 块。它配备了层缩放和随机深度。
def LayerScaleBlockClassAttention(
    projection_dim: int,
    num_heads: int,
    layer_norm_eps: float,
    init_values: float,
    mlp_units: typing.List[int],
    dropout_rate: float,
    sd_prob: float,
    name: str,
):
    """预归一化变压器块,旨在应用于 cls 令牌的嵌入和图像块的嵌入。

    包含 LayerScale 和随机深度。

    参数:
        projection_dim (int): 用于变压器块和图块投影层的投影维度。
        num_heads (int): 注意力头的数量。
        layer_norm_eps (float): 用于层归一化的 epsilon。
        init_values (float): 用于 LayerScale 中的对角矩阵的初始值。
        mlp_units (List[int]): 用于变压器块的前馈网络的维度。
        dropout_rate (float): 在注意力分数和最终投影输出中使用的丢弃率。
        sd_prob (float): 随机深度率。
        name (str): 块的名称标识符。

    返回:
        A keras.Model 实例。
    """
    x = keras.Input((None, projection_dim))
    x_cls = keras.Input((None, projection_dim))
    inputs = keras.layers.Concatenate(axis=1)([x_cls, x])

    # 类别注意力 (CA)。
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(inputs)
    attn_output, attn_scores = ClassAttention(projection_dim, num_heads, dropout_rate)(
        x1
    )
    attn_output = (
        LayerScale(init_values, projection_dim)(attn_output)
        if init_values
        else attn_output
    )
    attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
    x2 = keras.layers.Add()([x_cls, attn_output])

    # FFN。
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
    x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
    outputs = keras.layers.Add()([x2, x4])

    return keras.Model([x, x_cls], [outputs, attn_scores], name=name)


def LayerScaleBlock(
    projection_dim: int,
    num_heads: int,
    layer_norm_eps: float,
    init_values: float,
    mlp_units: typing.List[int],
    dropout_rate: float,
    sd_prob: float,
    name: str,
):
    """预归一化变压器块,旨在应用于图像块的嵌入。

    包含 LayerScale 和随机深度。

        参数:
            projection_dim (int): 用于变压器块和图块投影层的投影维度。
            num_heads (int): 注意力头的数量。
            layer_norm_eps (float): 用于层归一化的 epsilon。
            init_values (float): 用于 LayerScale 中的对角矩阵的初始值。
            mlp_units (List[int]): 用于变压器块的前馈网络的维度。
            dropout_rate (float): 在注意力分数和最终投影输出中使用的丢弃率。
            sd_prob (float): 随机深度率。
            name (str): 块的名称标识符。

    返回:
        A keras.Model 实例。
    """
    encoded_patches = keras.Input((None, projection_dim))

    # 自注意力。
    x1 = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
    attn_output, attn_scores = TalkingHeadAttention(
        projection_dim, num_heads, dropout_rate
    )(x1)
    attn_output = (
        LayerScale(init_values, projection_dim)(attn_output)
        if init_values
        else attn_output
    )
    attn_output = StochasticDepth(sd_prob)(attn_output) if sd_prob else attn_output
    x2 = layers.Add()([encoded_patches, attn_output])

    # FFN。
    x3 = layers.LayerNormalization(epsilon=layer_norm_eps)(x2)
    x4 = mlp(x3, hidden_units=mlp_units, dropout_rate=dropout_rate)
    x4 = LayerScale(init_values, projection_dim)(x4) if init_values else x4
    x4 = StochasticDepth(sd_prob)(x4) if sd_prob else x4
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, [outputs, attn_scores], name=name)

鉴于所有这些模块,我们现在准备将它们汇总成最终的CaiT模型。


将各部分拼接在一起:CaiT模型

class CaiT(keras.Model):
    """CaiT模型。

    参数:
        projection_dim (int): Transformers模块和补丁投影层中使用的投影维度。
        patch_size (int): 输入图像的补丁大小。
        num_patches (int): 提取图像补丁后的补丁数量。
        init_values (float): 在LayerScale中使用的对角矩阵的初始值。
        mlp_units: (List[int]): 在Transformer模块中使用的前馈网络的维度。
        sa_ffn_layers (int): 自注意力Transformer模块的数量。
        ca_ffn_layers (int): 类注意力Transformer模块的数量。
        num_heads (int): 注意力头的数量。
        layer_norm_eps (float): Layer Normalization使用的epsilon值。
        dropout_rate (float): 在注意力分数以及最终投影输出中使用的dropout率。
        sd_prob (float): 随机深度率。
        global_pool (str): 表示如何对来自最后一个Transformer模块的表示进行池化。
        pre_logits (bool): 如果设置为True,则不添加分类头。
        num_classes (int): 用于构建最终分类层的类的数量。
    """

    def __init__(
        self,
        projection_dim: int,
        patch_size: int,
        num_patches: int,
        init_values: float,
        mlp_units: typing.List[int],
        sa_ffn_layers: int,
        ca_ffn_layers: int,
        num_heads: int,
        layer_norm_eps: float,
        dropout_rate: float,
        sd_prob: float,
        global_pool: str,
        pre_logits: bool,
        num_classes: int,
        **kwargs,
    ):
        if global_pool not in ["token", "avg"]:
            raise ValueError(
                '接收到的`global_pool`值无效,应为`"token"`或`"avg"`。'
            )

        super().__init__(**kwargs)

        # 负责对输入图像进行补丁化并进行线性投影。
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=projection_dim,
                    kernel_size=(patch_size, patch_size),
                    strides=(patch_size, patch_size),
                    padding="VALID",
                    name="conv_projection",
                    kernel_initializer="lecun_normal",
                ),
                layers.Reshape(
                    target_shape=(-1, projection_dim),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # CLS token和位置嵌入。
        self.cls_token = self.add_weight(
            shape=(1, 1, projection_dim), initializer="zeros"
        )
        self.pos_embed = self.add_weight(
            shape=(1, num_patches, projection_dim), initializer="zeros"
        )

        # 投影dropout。
        self.pos_drop = layers.Dropout(dropout_rate, name="projection_dropout")

        # 随机深度调度。
        dpr = [sd_prob for _ in range(sa_ffn_layers)]

        # 仅对图像补丁嵌入操作的自注意力(SA) Transformer模块。
        self.blocks = [
            LayerScaleBlock(
                projection_dim=projection_dim,
                num_heads=num_heads,
                layer_norm_eps=layer_norm_eps,
                init_values=init_values,
                mlp_units=mlp_units,
                dropout_rate=dropout_rate,
                sd_prob=dpr[i],
                name=f"sa_ffn_block_{i}",
            )
            for i in range(sa_ffn_layers)
        ]

        # 对CLS token和图像补丁嵌入操作的类注意力(CA) Transformer模块。
        self.blocks_token_only = [
            LayerScaleBlockClassAttention(
                projection_dim=projection_dim,
                num_heads=num_heads,
                layer_norm_eps=layer_norm_eps,
                init_values=init_values,
                mlp_units=mlp_units,
                dropout_rate=dropout_rate,
                name=f"ca_ffn_block_{i}",
                sd_prob=0.0,  # 类注意力层中没有随机深度。
            )
            for i in range(ca_ffn_layers)
        ]

        # 预分类层归一化。
        self.norm = layers.LayerNormalization(epsilon=layer_norm_eps, name="head_norm")

        # 用于分类头的表示池化。
        self.global_pool = global_pool

        # 分类头。
        self.pre_logits = pre_logits
        self.num_classes = num_classes
        if not pre_logits:
            self.head = layers.Dense(num_classes, name="classification_head")

    def call(self, x, training=False):
        # 注意这里没有添加CLS token。
        x = self.projection(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # SA+FFN层。
        sa_ffn_attn = {}
        for blk in self.blocks:
            x, attn_scores = blk(x)
            sa_ffn_attn[f"{blk.name}_att"] = attn_scores

        # CA+FFN层。
        ca_ffn_attn = {}
        cls_tokens = ops.tile(self.cls_token, (ops.shape(x)[0], 1, 1))
        for blk in self.blocks_token_only:
            cls_tokens, attn_scores = blk([x, cls_tokens])
            ca_ffn_attn[f"{blk.name}_att"] = attn_scores

        x = ops.concatenate([cls_tokens, x], axis=1)
        x = self.norm(x)

        # 总是返回SA+FFN和CA+FFN层的注意力分数以便于使用。
        if self.global_pool:
            x = (
                ops.reduce_mean(x[:, 1:], axis=1)
                if self.global_pool == "avg"
                else x[:, 0]
            )
        return (
            (x, sa_ffn_attn, ca_ffn_attn)
            if self.pre_logits
            else (self.head(x), sa_ffn_attn, ca_ffn_attn)
        )

拥有这样的SA和CA层分离结构帮助模型更具体地关注潜在目标:

  • 图像补丁之间的模型依赖关系
  • 在CLS令牌中总结图像补丁的信息,可用于手头的任务

现在我们已经定义了CaiT模型,是时候对其进行测试了。我们将首先定义一个模型配置,该配置将传递给我们的CaiT类进行初始化。


定义模型配置

def get_config(
    image_size: int = 224,
    patch_size: int = 16,
    projection_dim: int = 192,
    sa_ffn_layers: int = 24,
    ca_ffn_layers: int = 2,
    num_heads: int = 4,
    mlp_ratio: int = 4,
    layer_norm_eps=1e-6,
    init_values: float = 1e-5,
    dropout_rate: float = 0.0,
    sd_prob: float = 0.0,
    global_pool: str = "token",
    pre_logits: bool = False,
    num_classes: int = 1000,
) -> typing.Dict:
    """CaiT模型的默认配置(cait_xxs24_224)。

    参考:
        https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cait.py
    """
    config = {}

    # 补丁和投影。
    config["patch_size"] = patch_size
    config["num_patches"] = (image_size // patch_size) ** 2

    # LayerScale.
    config["init_values"] = init_values

    # Dropout和随机深度。
    config["dropout_rate"] = dropout_rate
    config["sd_prob"] = sd_prob

    # 在不同块和层之间共享。
    config["layer_norm_eps"] = layer_norm_eps
    config["projection_dim"] = projection_dim
    config["mlp_units"] = [
        projection_dim * mlp_ratio,
        projection_dim,
    ]

    # 注意力层。
    config["num_heads"] = num_heads
    config["sa_ffn_layers"] = sa_ffn_layers
    config["ca_ffn_layers"] = ca_ffn_layers

    # 表示池化和特定于任务的参数。
    config["global_pool"] = global_pool
    config["pre_logits"] = pre_logits
    config["num_classes"] = num_classes

    return config

如果您已经了解ViT架构,大多数配置变量应该对您来说很熟悉。重点放在sa_ffn_layersca_ffn_layers上,这些会控制SA-Transformer块和CA-Transformer块的数量。您可以轻松修改此get_config()方法以实例化CaiT模型以用于自己的数据集。


模型实例化

image_size = 224
num_channels = 3
batch_size = 2

config = get_config()
cait_xxs24_224 = CaiT(**config)

dummy_inputs = ops.ones((batch_size, image_size, image_size, num_channels))
_ = cait_xxs24_224(dummy_inputs)

我们可以成功地对模型进行推理。但实现的正确性如何呢?有许多方法可以验证它:

  • 在ImageNet-1k验证集上获取模型的性能(假设它已经填充了预训练参数),因为预训练数据集是ImageNet-1k。
  • 在不同的数据集上微调模型。

为了验证这一点,我们将加载另一实例的相同模型,该模型已经填充了预训练参数。有关更多细节,请参考 这个仓库 (由本笔记本的作者开发)。此外,该仓库提供了代码,以验证模型在 ImageNet-1k验证集上的性能以及 微调


加载预训练模型

model_gcs_path = "gs://tfhub-modules/sayakpaul/cait_xxs24_224/1/uncompressed"
pretrained_model = keras.Sequential(
    [keras.layers.TFSMLayer(model_gcs_path, call_endpoint="serving_default")]
)

推理工具

在接下来的几个单元中,我们开发运行推理所需的预处理工具 与预训练模型一起使用。

# 预处理变换包括中心裁剪,并使用ImageNet-1k训练统计数据(均值和标准差)进行归一化。
crop_layer = keras.layers.CenterCrop(image_size, image_size)
norm_layer = keras.layers.Normalization(
    mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
    variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],
)


def preprocess_image(image, size=image_size):
    image = np.array(image)
    image_resized = ops.expand_dims(image, 0)
    resize_size = int((256 / image_size) * size)
    image_resized = ops.image.resize(
        image_resized, (resize_size, resize_size), interpolation="bicubic"
    )
    image_resized = crop_layer(image_resized)
    return norm_layer(image_resized).numpy()


def load_image_from_url(url):
    image_bytes = io.BytesIO(urlopen(url).read())
    image = PIL.Image.open(image_bytes)
    preprocessed_image = preprocess_image(image)
    return image, preprocessed_image

现在,我们检索ImageNet-1k标签并将其加载为模型我们正在 loading 是在 ImageNet-1k 数据集上进行预训练的。

# ImageNet-1k 类别标签。
imagenet_labels = (
    "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
)
label_path = keras.utils.get_file(origin=imagenet_labels)

with open(label_path, "r") as f:
    lines = f.readlines()
imagenet_labels = [line.rstrip() for line in lines]

加载图像

img_url = "https://i.imgur.com/ErgfLTn.jpg"
image, preprocessed_image = load_image_from_url(img_url)

# https://unsplash.com/photos/Ho93gVTRWW8
plt.imshow(image)
plt.axis("off")
plt.show()

png


获取预测结果

outputs = pretrained_model.predict(preprocessed_image)
logits = outputs["output_1"]
ca_ffn_block_0_att = outputs["output_3_ca_ffn_block_0_att"]
ca_ffn_block_1_att = outputs["output_3_ca_ffn_block_1_att"]

predicted_label = imagenet_labels[int(np.argmax(logits))]
print(predicted_label)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 30s 30s/step
monarch, monarch_butterfly, milkweed_butterfly, Danaus_plexippus

WARNING: 所有在 absl::InitializeLog() 调用之前的日志消息都写入 STDERR
I0000 00:00:1700601113.319904  361514 device_compiler.h:187] 使用 XLA 编译的集群! 该行在进程生命周期中最多记录一次。

现在我们已经获得了预测结果(似乎是预期的),我们可以进一步扩展我们的调查。根据 CaiT 的作者,我们可以调查来自注意力层的注意力分数。这有助于我们深入了解 CaiT 论文中引入的修改。


可视化注意力层

我们首先检查由类注意力层返回的注意力权重的形状。

# (batch_size, nb_attention_heads, num_cls_token, seq_length)
print("类注意力块的注意力分数形状:")
print(ca_ffn_block_0_att.shape)
类注意力块的注意力分数形状:
(1, 4, 1, 197)

形状表示我们获得了每个单独注意力头的注意力权重。它们量化了 CLS token 与自身及其余图像补丁之间的关系信息。

接下来,我们写一个实用程序来:

  • 可视化类注意力层中每个注意力头关注的内容。这有助于我们了解 CaiT 模型中如何诱导 空间-类关系
  • 从第一个类注意力层获取显著性图,帮助理解 CA 层如何聚合图像中感兴趣区域的信息。

这个实用程序参考了原始 CaiT 论文 的图 6 和 7。这也是 这个笔记本 (由本教程作者开发)的一部分。

# 参考:
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py

patch_size = 16


def get_cls_attention_map(
    attention_scores,
    return_saliency=False,
) -> np.ndarray:
    """
    返回来自特定注意力块的注意力分数。

    参数:
        attention_scores: 来自要可视化的注意力块的注意力分数。
        return_saliency: 如果设置为 True 也返回注意力块的显著性表示的布尔标志。
    """
    w_featmap = preprocessed_image.shape[2] // patch_size
    h_featmap = preprocessed_image.shape[1] // patch_size

    nh = attention_scores.shape[1]  # 注意力头的数量。

    # 取 CLS token 的表示。
    attentions = attention_scores[0, :, 0, 1:].reshape(nh, -1)

    # 重塑注意力分数以类似于小补丁。
    attentions = attentions.reshape(nh, w_featmap, h_featmap)

    if not return_saliency:
        attentions = attentions.transpose((1, 2, 0))

    else:
        attentions = np.mean(attentions, axis=0)
        attentions = (attentions - attentions.min()) / (
            attentions.max() - attentions.min()
        )
        attentions = np.expand_dims(attentions, -1)

    # 将注意力补丁调整为 224x224(224: 14x16)
    attentions = ops.image.resize(
        attentions,
        size=(h_featmap * patch_size, w_featmap * patch_size),
        interpolation="bicubic",
    )

    return attentions

在第一个 CA 层中,我们注意到模型只关注感兴趣区域。

attentions_ca_block_0 = get_cls_attention_map(ca_ffn_block_0_att)


fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(attentions_ca_block_0.shape[-1]):
    if img_count < attentions_ca_block_0.shape[-1]:
        axes[i].imshow(attentions_ca_block_0[:, :, img_count])
        axes[i].title.set_text(f"注意头: {img_count}")
        axes[i].axis("off")
        img_count += 1

fig.tight_layout()
plt.show()

png

在第二个CA层中,模型试图更多地关注包含辨别信号的上下文。

attentions_ca_block_1 = get_cls_attention_map(ca_ffn_block_1_att)


fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 13))
img_count = 0

for i in range(attentions_ca_block_1.shape[-1]):
    if img_count < attentions_ca_block_1.shape[-1]:
        axes[i].imshow(attentions_ca_block_1[:, :, img_count])
        axes[i].title.set_text(f"注意力头: {img_count}")
        axes[i].axis("off")
        img_count += 1

fig.tight_layout()
plt.show()

png

最后,我们获得了给定图像的显著性图。

saliency_attention = get_cls_attention_map(ca_ffn_block_0_att, return_saliency=True)

image = np.array(image)
image_resized = ops.expand_dims(image, 0)
resize_size = int((256 / 224) * image_size)
image_resized = ops.image.resize(
    image_resized, (resize_size, resize_size), interpolation="bicubic"
)
image_resized = crop_layer(image_resized)

plt.imshow(image_resized.numpy().squeeze().astype("int32"))
plt.imshow(saliency_attention.numpy().squeeze(), cmap="cividis", alpha=0.9)
plt.axis("off")

plt.show()

png


结论

在这个笔记本中,我们实现了CaiT模型。它展示了如何在保持预训练数据集不变的情况下,减轻在尝试扩大ViTs深度时所遇到的问题。我希望笔记本中提供的额外可视化能激发社区的兴奋,并促使人们开发出有趣的方法来探讨像ViT这样的模型学到了什么。


致谢

感谢谷歌的ML开发者项目团队提供的Google Cloud Platform支持。