代码示例 / 计算机视觉 / 焦点调制:自注意力的替代品

焦点调制:自注意力的替代品

作者: Aritra Roy Gosthipaty, Ritwik Raha
创建日期: 2023/01/25
最后修改: 2023/02/15
描述: 使用焦点调制网络进行图像分类。

在Colab中查看 GitHub源代码


引言

本教程旨在提供有关焦点调制网络实现的全面指南,如 Yang et al. 中所述。

本教程将以形式化、极简的方法提供焦点调制网络的实现,并探讨其在深度学习领域的潜在应用。

问题陈述

变换器架构(Vaswani et al.)已经成为大多数自然语言处理任务的事实标准,已被应用于计算机视觉领域,例如视觉变换器(Dosovitskiy et al.)。

在变换器中,自注意力(SA)可以说是其成功的关键,使输入依赖的全局交互成为可能,而卷积操作限制了共享内核的局部区域交互。

注意力模块的数学表达如方程1所示。

注意力方程
方程1:注意力的数学方程(来源:Aritra和Ritwik)

其中:

  • Q 是查询
  • K 是键
  • V 是值
  • d_k 是键的维度

自注意力中,查询、键和值都来自输入序列。让我们将自注意力的注意力方程重写为方程2

自注意力方程
方程2:自注意力的数学方程(来源:Aritra和Ritwik)

查看自注意力方程,我们看到它是一个二次方程。因此,随着标记数量的增加,计算时间(成本也会增加)。为了解决这个问题,并使变换器更具可解释性,Yang等人试图用更好的组件替代自注意力模块。

解决方案

Yang等人引入了焦点调制层,以作为自注意力层的无缝替代。该层具有高可解释性,使其成为深度学习从业者的有价值工具。

在本教程中,我们将深入探讨通过在CIFAR-10数据集上训练整个模型来应用该层,并直观地解释该层的性能。

注意:我们尝试将我们的实现与 官方实现 对齐。


设置和导入

我们为本教程使用tensorflow版本2.11.0

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers.experimental import AdamW
from typing import Optional, Tuple, List
from matplotlib import pyplot as plt
from random import randint

# 设置随机种子以确保可重复性。
tf.keras.utils.set_random_seed(42)

全局配置

我们没有选择这些超参数的强大理由。请随意更改配置并训练模型。

# 数据
TRAIN_SLICE = 40000
BUFFER_SIZE = 2048
BATCH_SIZE = 1024
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
IMAGE_SIZE = 48
NUM_CLASSES = 10

# 优化器
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# 训练
EPOCHS = 25

加载并处理CIFAR-10数据集

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
    (x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
)
从 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 下载数据
170498071/170498071 [==============================] - 30s 0us/步

构建增强

我们使用keras.Sequential API将所有单独的增强步骤组合成一个API。

# Build the `train` augmentation pipeline.
train_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_aug = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
    ],
    name="test_data_augmentation",
)

构建 tf.data 管道

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(
        lambda image, label: (train_aug(image), label), num_parallel_calls=AUTO
    )
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) 已弃用,将在 2023-09-23 之后删除。
更新说明:
将不再假设 Lambda 函数在它们使用的语句中使用,或至少在同一块中使用。 https://github.com/tensorflow/tensorflow/issues/56089

架构

我们在这里暂停一下,快速查看聚焦调制网络的架构。 图 1 展示了每个单独层如何汇编成一个单一模型。这为我们提供了整个架构的鸟瞰图。

模型图
图 1:聚焦调制模型的示意图(来源:Aritra 和 Ritwik)

在接下来的部分中,我们将深入研究这些层。这是我们将遵循的顺序:

  • 补丁嵌入层
  • 聚焦调制块
    • 多层感知器
    • 聚焦调制层
      • 分层上下文化
      • 门控聚合
    • 构建聚焦调制块
  • 构建基础层

为了更好地理解架构,以我们熟悉的格式,我们来看看聚焦调制网络绘制成 Transformer 架构的样子。

图 2 显示了传统 Transformer 架构的编码器层,其中自注意力被聚焦调制层替换。

蓝色 方块代表聚焦调制块。这些块的堆叠构建一个单一的基础层。绿色 方块代表聚焦调制层。

整体架构
图 2:整体架构(来源:Aritra 和 Ritwik)

补丁嵌入层

补丁嵌入层用于将输入图像分块并投影到潜在空间。此层还用作架构中的下采样层。

class PatchEmbed(layers.Layer):
    """图像补丁嵌入层,还充当下采样层。

    Args:
        image_size (Tuple[int]): 输入图像分辨率。
        patch_size (Tuple[int]): 补丁空间分辨率。
        embed_dim (int): 嵌入维度。
    """

    def __init__(
        self,
        image_size: Tuple[int] = (224, 224),
        patch_size: Tuple[int] = (4, 4),
        embed_dim: int = 96,
        **kwargs,
    ):
        super().__init__(**kwargs)
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]
        self.proj = layers.Conv2D(
            filters=embed_dim, kernel_size=patch_size, strides=patch_size
        )
        self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
        self.norm = keras.layers.LayerNormalization(epsilon=1e-7)

    def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, int, int, int]:
        """将图像分块并转换为标记。

        Args:
            x: 形状为 (B, H, W, C) 的张量

        Returns:
            处理后的张量、投影特征图的高度、投影特征图的宽度、
            投影特征图的通道数的元组。
        """
        # 投影输入。
        x = self.proj(x)

        # 从投影张量获取形状。
        height = tf.shape(x)[1]
        width = tf.shape(x)[2]
        channels = tf.shape(x)[3]

        # B, H, W, C -> B, H*W, C
        x = self.norm(self.flatten(x))

        return x, height, width, channels

聚焦调制块

聚焦调制块可以被视为单个 Transformer 块,其中自注意力(SA)模块被聚焦调制模块替换,如我们在 图 2 中看到的那样。

让我们回想一下聚焦调制块的应有样子。 图3.

焦点调制块
图3:焦点调制块的孤立视图(来源:Aritra和Ritwik)

焦点调制块由以下部分组成: - 多层感知器 - 焦点调制层

多层感知器

def MLP(
    in_features: int,
    hidden_features: Optional[int] = None,
    out_features: Optional[int] = None,
    mlp_drop_rate: float = 0.0,
):
    hidden_features = hidden_features or in_features
    out_features = out_features or in_features

    return keras.Sequential(
        [
            layers.Dense(units=hidden_features, activation=keras.activations.gelu),
            layers.Dense(units=out_features),
            layers.Dropout(rate=mlp_drop_rate),
        ]
    )

焦点调制层

在典型的变换器架构中,对于输入特征图 X in R^{HxWxC} 中的每个视觉标记(查询x_i in R^C,一个 通用编码过程 生成特征表示 y_i in R^C

编码过程包括 交互(例如,点积)和 聚合(例如,加权均值)。

我们将在这里讨论两种类型的编码: - 在 自注意力 中的交互然后聚合 - 在 焦点调制 中的聚合然后交互

自注意力

自注意力表达式
图4:自注意力模块。(来源:Aritra和Ritwik)
自注意力中的聚合和交互
公式3: 自注意力中的聚合和交互(来源:Aritra和Ritwik)

图4所示,查询和键在交互步骤中相互作用,输出注意力分数。接下来是值的加权聚合,称为聚合步骤。

焦点调制

焦点调制模块
图5:焦点调制模块。(来源:Aritra和Ritwik)
焦点调制中的聚合和交互
公式4: 焦点调制中的聚合和交互(来源:Aritra和Ritwik)

图5描述了焦点调制层。 q() 是查询投影函数。它是一个 线性层,将查询投影到潜在空间。 m() 是上下文聚合函数。与自注意力不同,聚合步骤在焦点调制中发生在交互步骤之前。

虽然 q() 很容易理解,但上下文聚合函数 m() 更为复杂。因此,本节将重点关注 m()

上下文聚合
图6:上下文聚合函数 m()。(来源:Aritra和Ritwik)

上下文聚合函数 m()图6中所示的两个部分组成: - 分层上下文化 - 门控聚合

分层上下文化

分层上下文化
图7:分层上下文化(来源:Aritra和Ritwik)

图7中,我们可以看到,输入首先被线性投影。这一线性投影生成 Z^0。其中 Z^0 可以表示如下:

z_not的线性投影
公式5: Z^0 的线性投影(来源:Aritra和Ritwik)

Z^0 然后被传递到一系列深度卷积(DWConv)和 GeLU 层。作者将每个 DWConv 和 GeLU 的区块称为级别,表示为 l。在图6中我们有两个级别。数学上表示为:

调制层的级别
公式6: 调制层的级别(来源:Aritra和Ritwik)

其中 l in {1, ... , L}

最终特征图经过全局平均池化层。这可以表示如下:

平均池化
公式7: 最终特征的平均池化(来源:Aritra和Ritwik)

门控聚合

门控聚合
图8:门控聚合(来源:Aritra和Ritwik)

现在由于分层上下文化步骤我们获得了 L+1 个中间特征图,我们需要一个门控机制,让某些特征通过并禁止其他特征。这可以通过注意力模块实现。 稍后在教程中,我们将可视化这些门,以更好地理解它们的有用性。

首先,我们构建聚合的权重。这里我们对输入特征图应用一个 线性层,将其投影到 L+1 维度。

| 门 | | :–: | | 公式8: 门(来源:Aritra和Ritwik) | Next we perform the weighted aggregation over the contexts.

z out
方程 9: 最终特征图 (来源: Aritra 和 Ritwik)

To enable communication across different channels, we use another linear layer h() to obtain the modulator

Modulator
方程 10: 调制器 (来源: Aritra 和 Ritwik)

To sum up the Focal Modulation layer we have:

Focal Modulation Layer
方程 11: 焦点调制层 (来源: Aritra 和 Ritwik)
class FocalModulationLayer(layers.Layer):
    """The Focal Modulation layer includes query projection & context aggregation.

    Args:
        dim (int): Projection dimension.
        focal_window (int): Window size for focal modulation.
        focal_level (int): The current focal level.
        focal_factor (int): Factor of focal modulation.
        proj_drop_rate (float): Rate of dropout.
    """

    def __init__(
        self,
        dim: int,
        focal_window: int,
        focal_level: int,
        focal_factor: int = 2,
        proj_drop_rate: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.proj_drop_rate = proj_drop_rate

        # 使用线性层将输入特征投影到新的特征空间。
        # 注意使用的 `units`。我们将一次性投影输入特征,并将投影拆分为查询、上下文和门。
        self.initial_proj = layers.Dense(
            units=(2 * self.dim) + (self.focal_level + 1),
            use_bias=True,
        )
        self.focal_layers = list()
        self.kernel_sizes = list()
        for idx in range(self.focal_level):
            kernel_size = (self.focal_factor * idx) + self.focal_window
            depth_gelu_block = keras.Sequential(
                [
                    layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2)),
                    layers.Conv2D(
                        filters=self.dim,
                        kernel_size=kernel_size,
                        activation=keras.activations.gelu,
                        groups=self.dim,
                        use_bias=False,
                    ),
                ]
            )
            self.focal_layers.append(depth_gelu_block)
            self.kernel_sizes.append(kernel_size)
        self.activation = keras.activations.gelu
        self.gap = layers.GlobalAveragePooling2D(keepdims=True)
        self.modulator_proj = layers.Conv2D(
            filters=self.dim,
            kernel_size=(1, 1),
            use_bias=True,
        )
        self.proj = layers.Dense(units=self.dim)
        self.proj_drop = layers.Dropout(self.proj_drop_rate)

    def call(self, x: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
        """Forward pass of the layer.

        Args:
            x: Tensor of shape (B, H, W, C)
        """
        # 将线性投影应用于输入特征图
        x_proj = self.initial_proj(x)

        # 将投影后的 x 拆分为查询、上下文和门
        query, context, self.gates = tf.split(
            value=x_proj,
            num_or_size_splits=[self.dim, self.dim, self.focal_level + 1],
            axis=-1,
        )

        # 上下文聚合
        context = self.focal_layers[0](context)
        context_all = context * self.gates[..., 0:1]
        for idx in range(1, self.focal_level):
            context = self.focal_layers[idx](context)
            context_all += context * self.gates[..., idx : idx + 1]

        # 构建全局上下文
        context_global = self.activation(self.gap(context))
        context_all += context_global * self.gates[..., self.focal_level :]

        # 焦点调制
        self.modulator = self.modulator_proj(context_all)
        x_output = query * self.modulator

        # 投影输出并应用 dropout
        x_output = self.proj(x_output)
        x_output = self.proj_drop(x_output)

        return x_output

The Focal Modulation block

Finally, we have all the components we need to build the Focal Modulation block. Here we take the MLP and Focal Modulation layer together and build the Focal Modulation block.

class FocalModulationBlock(layers.Layer):
    """组合 FFN 和焦点调制层。

    Args:
        dim (int): 输入通道数量。
        input_resolution (Tuple[int]): 输入分辨率。
        mlp_ratio (float): mlp 隐藏维度与嵌入维度的比率。
        drop (float): 丢弃率。
        drop_path (float): 随机深度率。
        focal_level (int): 焦点级别数量。
        focal_window (int): 第一级焦点窗口大小。
    """

    def __init__(
        self,
        dim: int,
        input_resolution: Tuple[int],
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        focal_level: int = 1,
        focal_window: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio
        self.focal_level = focal_level
        self.focal_window = focal_window
        self.norm = layers.LayerNormalization(epsilon=1e-5)
        self.modulation = FocalModulationLayer(
            dim=self.dim,
            focal_window=self.focal_window,
            focal_level=self.focal_level,
            proj_drop_rate=drop,
        )
        mlp_hidden_dim = int(self.dim * self.mlp_ratio)
        self.mlp = MLP(
            in_features=self.dim,
            hidden_features=mlp_hidden_dim,
            mlp_drop_rate=drop,
        )

    def call(self, x: tf.Tensor, height: int, width: int, channels: int) -> tf.Tensor:
        """通过焦点调制块处理输入张量。

        Args:
            x (tf.Tensor): 形状为 (B, L, C) 的输入
            height (int): 特征图的高度
            width (int): 特征图的宽度
            channels (int): 特征图的通道数

        Returns:
            处理后的张量。
        """
        shortcut = x

        # 焦点调制
        x = tf.reshape(x, shape=(-1, height, width, channels))
        x = self.modulation(x)
        x = tf.reshape(x, shape=(-1, height * width, channels))

        # FFN
        x = shortcut + x
        x = x + self.mlp(self.norm(x))
        return x

基础层

基础层由一组焦点调制模块组成。这在图9中 illustrated.

基础层
图9:基础层,焦点调制模块的集合。(来源:Aritra和Ritwik)

请注意,在图9中,有多个焦点调制模块用Nx表示。这表明基础层是焦点调制模块的集合。

class BasicLayer(layers.Layer):
    """焦点调制模块的集合。

    参数:
        dim (int): 模型的维度。
        out_dim (int): Patch嵌入层使用的维度。
        input_resolution (Tuple[int]): 输入图像分辨率。
        depth (int): 焦点调制模块的数量。
        mlp_ratio (float): mlp隐藏维度与嵌入维度的比率。
        drop (float): 随机失活率。
        downsample (tf.keras.layers.Layer): 层末尾的下采样层。
        focal_level (int): 当前焦点级别。
        focal_window (int): 使用的焦点窗口。
    """

    def __init__(
        self,
        dim: int,
        out_dim: int,
        input_resolution: Tuple[int],
        depth: int,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        downsample=None,
        focal_level: int = 1,
        focal_window: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.blocks = [
            FocalModulationBlock(
                dim=dim,
                input_resolution=input_resolution,
                mlp_ratio=mlp_ratio,
                drop=drop,
                focal_level=focal_level,
                focal_window=focal_window,
            )
            for i in range(self.depth)
        ]

        # 层末尾的下采样层
        if downsample is not None:
            self.downsample = downsample(
                image_size=input_resolution,
                patch_size=(2, 2),
                embed_dim=out_dim,
            )
        else:
            self.downsample = None

    def call(
        self, x: tf.Tensor, height: int, width: int, channels: int
    ) -> Tuple[tf.Tensor, int, int, int]:
        """层的前向传播。

        参数:
            x (tf.Tensor): 形状为(B, L, C)的张量
            height (int): 特征图的高度
            width (int): 特征图的宽度
            channels (int): 特征图的嵌入维度

        返回:
            处理后的张量、改变的高度、宽度和
            张量的维度的元组。
        """
        # 应用焦点调制模块
        for block in self.blocks:
            x = block(x, height, width, channels)

        # 除最后一个基础层外,所有层的末尾都有
        # 下采样。
        if self.downsample is not None:
            x = tf.reshape(x, shape=(-1, height, width, channels))
            x, height_o, width_o, channels_o = self.downsample(x)
        else:
            height_o, width_o, channels_o = height, width, channels

        return x, height_o, width_o, channels_o

焦点调制网络模型

这是将所有内容连接在一起的模型。 它由一组基础层和一个分类头组成。 有关此结构的回顾,请参见图1

class FocalModulationNetwork(keras.Model):
    """聚焦调制网络。

    参数:
        image_size (Tuple[int]): 使用的图像空间大小。
        patch_size (Tuple[int]): 每个补丁的补丁大小。
        num_classes (int): 分类使用的类别数量。
        embed_dim (int): 补丁嵌入维度。
        depths (List[int]): 每个聚焦变换器块的深度。
        mlp_ratio (float): MLP中间层扩展的比例。
        drop_rate (float): FM和MLP层的丢弃率。
        focal_levels (list): 在各个阶段的聚焦级别数量。
            注意这不包括最细粒度级别。
        focal_windows (list): 在各个阶段的聚焦窗口大小。
    """

    def __init__(
        self,
        image_size: Tuple[int] = (48, 48),
        patch_size: Tuple[int] = (4, 4),
        num_classes: int = 10,
        embed_dim: int = 256,
        depths: List[int] = [2, 3, 2],
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.1,
        focal_levels=[2, 2, 2],
        focal_windows=[3, 3, 3],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = len(depths)
        embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_features = embed_dim[-1]
        self.mlp_ratio = mlp_ratio
        self.patch_embed = PatchEmbed(
            image_size=image_size,
            patch_size=patch_size,
            embed_dim=embed_dim[0],
        )
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patch_resolution
        self.patches_resolution = patches_resolution
        self.pos_drop = layers.Dropout(drop_rate)
        self.basic_layers = list()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=embed_dim[i_layer],
                out_dim=embed_dim[i_layer + 1]
                if (i_layer < self.num_layers - 1)
                else None,
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=depths[i_layer],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
                focal_level=focal_levels[i_layer],
                focal_window=focal_windows[i_layer],
            )
            self.basic_layers.append(layer)
        self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
        self.avgpool = layers.GlobalAveragePooling1D()
        self.flatten = layers.Flatten()
        self.head = layers.Dense(self.num_classes, activation="softmax")

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """层的前向传播。

        参数:
            x: 形状为 (B, H, W, C) 的张量

        返回:
            逻辑值。
        """
        # 对输入图像进行补丁嵌入。
        x, height, width, channels = self.patch_embed(x)
        x = self.pos_drop(x)

        for idx, layer in enumerate(self.basic_layers):
            x, height, width, channels = layer(x, height, width, channels)

        x = self.norm(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.head(x)
        return x

训练模型

现在所有组件都到位,架构也已建立,我们准备好将其用于实际应用。

在本节中,我们将在CIFAR-10数据集上训练我们的焦点调制模型。

可视化回调

焦点调制网络的一个关键特征是显式输入依赖性。这意味着调制器是通过观察目标位置周围的局部特征来计算的,因此它依赖于输入。简单来说,这样使得解释变得容易。我们可以简单地将门控值和原始图像并排放置,以查看门控机制的工作原理。

论文的作者可视化门和调制器,以便关注焦点调制层的可解释性。下面是一个可视化回调,它在模型训练时显示模型中特定层的门和调制器。

稍后我们会注意到,随着模型的训练,可视化效果越来越好。

门似乎选择性地允许输入图像的某些方面通过,同时轻柔地忽略其他方面,最终提高了分类准确性。

def display_grid(
    test_images: tf.Tensor,
    gates: tf.Tensor,
    modulator: tf.Tensor,
):
    """显示带有门控和调制器叠加的图像。

    参数:
        test_images (tf.Tensor): 一批测试图像。
        gates (tf.Tensor): 焦点调制层的门控。
        modulator (tf.Tensor): 焦点调制层的调制器。
    """
    fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))

    # 从批次中随机抽取一张图像。
    index = randint(0, BATCH_SIZE - 1)
    orig_image = test_images[index]
    gate_image = gates[index]
    modulator_image = modulator[index]

    # 原始图像
    ax[0].imshow(orig_image)
    ax[0].set_title("Original:")
    ax[0].axis("off")

    for index in range(1, 5):
        img = ax[index].imshow(orig_image)
        if index != 4:
            overlay_image = gate_image[..., index - 1]
            title = f"G {index}:"
        else:
            overlay_image = tf.norm(modulator_image, ord=2, axis=-1)
            title = f"MOD:"

        ax[index].imshow(
            overlay_image, cmap="inferno", alpha=0.6, extent=img.get_extent()
        )
        ax[index].set_title(title)
        ax[index].axis("off")

    plt.axis("off")
    plt.show()
    plt.close()

训练监控

# 采用一批测试输入来测量模型的进展。
test_images, test_labels = next(iter(test_ds))
upsampler = tf.keras.layers.UpSampling2D(
    size=(4, 4),
    interpolation="bilinear",
)


class TrainMonitor(keras.callbacks.Callback):
    def __init__(self, epoch_interval=None):
        self.epoch_interval = epoch_interval

    def on_epoch_end(self, epoch, logs=None):
        if self.epoch_interval and epoch % self.epoch_interval == 0:
            _ = self.model(test_images)

            # 获取中间层进行可视化
            gates = self.model.basic_layers[1].blocks[-1].modulation.gates
            gates = upsampler(gates)
            modulator = self.model.basic_layers[1].blocks[-1].modulation.modulator
            modulator = upsampler(modulator)

            # 显示门和调制器的网格。
            display_grid(test_images=test_images, gates=gates, modulator=modulator)

学习率调度器

# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(
        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
    ):
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        if self.total_steps < self.warmup_steps:
            raise ValueError("Total_steps must be larger or equal to warmup_steps.")
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
        if self.warmup_steps > 0:
            if self.learning_rate_base < self.warmup_learning_rate:
                raise ValueError(
                    "Learning_rate_base must be larger or equal to "
                    "warmup_learning_rate."
                )
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )


total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
    learning_rate_base=LEARNING_RATE,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

初始化、编译和训练模型

focal_mod_net = FocalModulationNetwork()
optimizer = AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)

# 编译并训练模型。
focal_mod_net.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
history = focal_mod_net.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=[TrainMonitor(epoch_interval=10)],
)
第 1 轮 / 25
40/40 [==============================] - ETA: 0s - loss: 2.3925 - accuracy: 0.1401

png

40/40 [==============================] - 57s 724ms/step - loss: 2.3925 - accuracy: 0.1401 - val_loss: 2.2182 - val_accuracy: 0.1768
第 2 轮 / 25
40/40 [==============================] - 20s 483ms/step - loss: 2.0790 - accuracy: 0.2261 - val_loss: 2.2933 - val_accuracy: 0.1795
第 3 轮 / 25
40/40 [==============================] - 19s 479ms/step - loss: 2.0130 - accuracy: 0.2585 - val_loss: 2.6833 - val_accuracy: 0.2022
第 4 轮 / 25
40/40 [==============================] - 21s 507ms/step - loss: 1.8270 - accuracy: 0.3315 - val_loss: 1.9127 - val_accuracy: 0.3215
第 5 轮 / 25
40/40 [==============================] - 19s 475ms/step - loss: 1.6037 - accuracy: 0.4173 - val_loss: 1.7226 - val_accuracy: 0.3938
第 6 轮 / 25
40/40 [==============================] - 19s 476ms/step - loss: 1.4758 - accuracy: 0.4658 - val_loss: 1.5097 - val_accuracy: 0.4733
第 7 轮 / 25
40/40 [==============================] - 19s 476ms/step - loss: 1.3677 - accuracy: 0.5075 - val_loss: 1.4630 - val_accuracy: 0.4986
第 8 轮 / 25
40/40 [==============================] - 21s 508ms/step - loss: 1.2599 - accuracy: 0.5490 - val_loss: 1.2908 - val_accuracy: 0.5492
第 9 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 1.1689 - accuracy: 0.5818 - val_loss: 1.2750 - val_accuracy: 0.5518
第 10 轮 / 25
40/40 [==============================] - 19s 476ms/step - loss: 1.0843 - accuracy: 0.6140 - val_loss: 1.1444 - val_accuracy: 0.6002
第 11 轮 / 25
39/40 [============================>.] - ETA: 0s - loss: 1.0040 - accuracy: 0.6453

png

40/40 [==============================] - 20s 489ms/step - loss: 1.0041 - accuracy: 0.6452 - val_loss: 1.1765 - val_accuracy: 0.5939
第 12 轮 / 25
40/40 [==============================] - 20s 480ms/step - loss: 0.9401 - accuracy: 0.6701 - val_loss: 1.1276 - val_accuracy: 0.6181
第 13 轮 / 25
40/40 [==============================] - 19s 480ms/step - loss: 0.8787 - accuracy: 0.6910 - val_loss: 0.9990 - val_accuracy: 0.6547
第 14 轮 / 25
40/40 [==============================] - 19s 479ms/step - loss: 0.8198 - accuracy: 0.7122 - val_loss: 1.0074 - val_accuracy: 0.6562
第 15 轮 / 25
40/40 [==============================] - 19s 480ms/step - loss: 0.7831 - accuracy: 0.7275 - val_loss: 0.9739 - val_accuracy: 0.6686
第 16 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.7358 - accuracy: 0.7428 - val_loss: 0.9578 - val_accuracy: 0.6753
第 17 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.7018 - accuracy: 0.7557 - val_loss: 0.9414 - val_accuracy: 0.6789
第 18 轮 / 25
40/40 [==============================] - 20s 480ms/step - loss: 0.6678 - accuracy: 0.7678 - val_loss: 0.9492 - val_accuracy: 0.6771
第 19 轮 / 25
40/40 [==============================] - 19s 476ms/step - loss: 0.6423 - accuracy: 0.7783 - val_loss: 0.9422 - val_accuracy: 0.6832
第 20 轮 / 25
40/40 [==============================] - 19s 479ms/step - loss: 0.6202 - accuracy: 0.7868 - val_loss: 0.9324 - val_accuracy: 0.6860
第 21 轮 / 25
40/40 [==============================] - ETA: 0s - loss: 0.6005 - accuracy: 0.7938

png

40/40 [==============================] - 20s 488ms/step - loss: 0.6005 - accuracy: 0.7938 - val_loss: 0.9326 - val_accuracy: 0.6880
第 22 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.5937 - accuracy: 0.7970 - val_loss: 0.9339 - val_accuracy: 0.6875
第 23 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.5899 - accuracy: 0.7984 - val_loss: 0.9294 - val_accuracy: 0.6894
第 24 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.5840 - accuracy: 0.8012 - val_loss: 0.9315 - val_accuracy: 0.6881
第 25 轮 / 25
40/40 [==============================] - 19s 478ms/step - loss: 0.5853 - accuracy: 0.7997 - val_loss: 0.9315 - val_accuracy: 0.6880

绘制损失和准确度

plt.plot(history.history["loss"], label="损失")
plt.plot(history.history["val_loss"], label="验证损失")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="准确率")
plt.plot(history.history["val_accuracy"], label="验证准确率")
plt.legend()
plt.show()

png

png


测试可视化

让我们在一些测试图像上测试我们的模型,看看门控的样子。

test_images, test_labels = next(iter(test_ds))
_ = focal_mod_net(test_images)

# 取中间层进行可视化
gates = focal_mod_net.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = focal_mod_net.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)

# 绘制测试图像,并叠加门控和调制器。
for row in range(5):
    display_grid(
        test_images=test_images,
        gates=gates,
        modulator=modulator,
    )

png

png

png

png

png


结论

所提出的架构,即焦点调制网络架构,是一种允许图像不同部分以取决于图像本身的方式相互作用的机制。 它的工作原理是首先收集图像每个部分(“查询标记”)周围的不同层次的上下文信息,然后使用门控决定哪些上下文信息最相关,最后以简单但有效的方式组合所选信息。

这被视为取代变压器架构中的自注意力机制。使这项研究显著的关键特征不是无注意力网络的概念,而是引入了一种同样强大的可解释架构。

作者还提到,他们创建了一系列焦点调制网络(FocalNets),这些网络显著超越了自注意力的同行,且参数和预训练数据的数量更少。

FocalNets架构有潜力提供令人印象深刻的结果,并且提供了简单的实现。它的良好表现和易用性使其成为研究人员在自己项目中探索自注意力的有吸引力的替代方案。它有可能在不久的将来被深度学习社区广泛采用。


致谢

我们要感谢 PyImageSearch 为我们提供了Colab Pro账户,感谢 JarvisLabs.ai 提供GPU积分,感谢微软研究院提供他们论文的官方实现。 我们还要特别感谢论文的第一作者 Jianwei Yang,他对本教程进行了广泛的审阅。