
A Vision Transformer without Attention

Author: Aritra Roy Gosthipaty, Ritwik Raha, Shivalika Singh
Date created: 2022/02/24
Last modified: 2022/10/15
Description: A minimal implementation of ShiftViT.

View in Colab • GitHub source


Introduction

Vision Transformers (ViTs) have sparked a wave of research at the intersection of Transformers and Computer Vision (CV).

ViTs can simultaneously model long- and short-range dependencies, thanks to the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe that the success of ViTs is purely due to the attention layer, and they seldom think about other parts of the ViT model.

In the academic paper When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism, the authors propose to demystify the success of ViTs by replacing the attention operation with a parameter-free operation. They swap the attention operation with a shifting operation.

In this example, we minimally implement the paper, aligning closely with the authors' official implementation.

This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed using the following command:

!pip install -qq -U tensorflow-addons

Setup and imports

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import pathlib
import glob

# Setting seed for reproducibility
SEED = 42
keras.utils.set_random_seed(SEED)

Hyperparameters

These are the hyperparameters that we have chosen for the experiment. Please feel free to tune them.

class Config(object):
    # DATA
    batch_size = 256
    buffer_size = batch_size * 2
    input_shape = (32, 32, 3)
    num_classes = 10

    # AUGMENTATION
    image_size = 48

    # ARCHITECTURE
    patch_size = 4
    projected_dim = 96
    num_shift_blocks_per_stages = [2, 4, 8, 2]
    epsilon = 1e-5
    stochastic_depth_rate = 0.2
    mlp_dropout_rate = 0.2
    num_div = 12
    shift_pixel = 1
    mlp_expand_ratio = 2

    # OPTIMIZER
    lr_start = 1e-5
    lr_max = 1e-3
    weight_decay = 1e-4

    # TRAINING
    epochs = 100

    # INFERENCE
    label_map = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }
    tf_ds_batch_size = 20


config = Config()

Load the CIFAR-10 dataset

We use the CIFAR-10 dataset for our experiments.

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

AUTO = tf.data.AUTOTUNE
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 3s 0us/step
Training samples: 40000
Validation samples: 10000
Testing samples: 10000

Data Augmentation

The augmentation pipeline consists of:

  • Rescaling
  • Resizing
  • Random cropping
  • Random horizontal flipping

Note: The image data augmentation layers do not apply data transformations at inference time. This means that these layers behave differently when called with training=False. Refer to the documentation for more details.

def get_augmentation_model():
    """Build the data augmentation model."""
    data_augmentation = keras.Sequential(
        [
            layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
            layers.RandomCrop(config.image_size, config.image_size),
            layers.RandomFlip("horizontal"),
            layers.Rescaling(1 / 255.0),
        ]
    )
    return data_augmentation
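
As a quick illustration (not part of the original example), we can pass the same batch through the augmentation model twice: with training=True the random crop and flip typically produce different results on each call, while with training=False the pipeline is deterministic and only the resizing and rescaling are applied.

augmenter = get_augmentation_model()
sample_batch = tf.convert_to_tensor(x_train[:4], dtype=tf.float32)

aug_a = augmenter(sample_batch, training=True)
aug_b = augmenter(sample_batch, training=True)
det_a = augmenter(sample_batch, training=False)
det_b = augmenter(sample_batch, training=False)

# The stochastic calls usually differ; the deterministic calls always match.
print("training=True calls identical:", bool(tf.reduce_all(aug_a == aug_b)))
print("training=False calls identical:", bool(tf.reduce_all(det_a == det_b)))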

The ShiftViT architecture

In this section, we build the architecture proposed in the ShiftViT paper.

Figure 1: The entire architecture of ShiftViT. Source

As shown in Figure 1, the architecture is inspired by Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. Here the authors propose a modular architecture with 4 stages. Each stage works on its own spatial size, creating a hierarchical architecture.

An input image of size HxWx3 is split into non-overlapping patches of size 4x4. This is done via the patchify layer, which results in individual tokens of feature size 48 (4x4x3). Each stage comprises two parts:

  1. Embedding Generation
  2. Stacked Shift Blocks

We discuss the stages and their modules in detail in what follows.

Note: Compared with the official implementation, we restructure a few key components to better fit the Keras API.
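
Before diving into the blocks, here is a rough shape walkthrough of the patchify step (an illustrative sketch, not part of the original example), assuming the 48x48 augmented CIFAR-10 inputs and the hyperparameters from Config:

dummy_images = tf.zeros((1, config.image_size, config.image_size, 3))

# Non-overlapping 4x4 patches: each token carries 4 * 4 * 3 = 48 raw features.
raw_patches = tf.image.extract_patches(
    images=dummy_images,
    sizes=[1, config.patch_size, config.patch_size, 1],
    strides=[1, config.patch_size, config.patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
print(raw_patches.shape)  # (1, 12, 12, 48)

# In this example, patchify + embedding is a single strided Conv2D that projects
# each patch to `projected_dim` channels (see the patch_projection layer later on).
projected = layers.Conv2D(
    filters=config.projected_dim,
    kernel_size=config.patch_size,
    strides=config.patch_size,
    padding="same",
)(dummy_images)
print(projected.shape)  # (1, 12, 12, 96)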

The ShiftViT Block

Figure 2: From the model to a shift block.

Each stage in the ShiftViT architecture contains a Shift Block, as shown in Figure 2.

Figure 3: The Shift ViT Block. Source

As shown in Figure 3, the Shift Block consists of the following:

  1. Shift Operation
  2. Linear Normalization
  3. MLP layer

The MLP block

The MLP block is intended to be a stack of densely-connected layers.

class MLP(layers.Layer):
    """Get the MLP layer for each shift block.

    Args:
        mlp_expand_ratio (int): The ratio by which the first feature map is expanded.
        mlp_dropout_rate (float): The rate for dropout.
    """

    def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
        super().__init__(**kwargs)
        self.mlp_expand_ratio = mlp_expand_ratio
        self.mlp_dropout_rate = mlp_dropout_rate

    def build(self, input_shape):
        input_channels = input_shape[-1]
        initial_filters = int(self.mlp_expand_ratio * input_channels)

        self.mlp = keras.Sequential(
            [
                layers.Dense(
                    units=initial_filters,
                    activation=tf.nn.gelu,
                ),
                layers.Dropout(rate=self.mlp_dropout_rate),
                layers.Dense(units=input_channels),
                layers.Dropout(rate=self.mlp_dropout_rate),
            ]
        )

    def call(self, x):
        x = self.mlp(x)
        return x

The DropPath layer

Stochastic depth is a regularization technique that randomly drops a set of layers. During inference, the layers are kept as they are. It is very similar to Dropout, but it operates on a block of layers rather than on the individual nodes inside a layer.

class DropPath(layers.Layer):
    """Drop Path also known as the Stochastic Depth layer.

    References:
        - https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
        - github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path_prob, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_prob = drop_path_prob

    def call(self, x, training=False):
        if training:
            keep_prob = 1 - self.drop_path_prob
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
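
As a quick sanity check (an illustrative sketch, not part of the original example), we can push a dummy batch through DropPath. With a high drop probability, some samples in the batch are zeroed out entirely during training while the kept ones are rescaled by 1 / keep_prob; at inference the input passes through unchanged.

drop_path = DropPath(drop_path_prob=0.5)
dummy = tf.ones((4, 2, 2, 3))

# During training, whole samples are either dropped (all-zero) or rescaled.
print(tf.reduce_sum(drop_path(dummy, training=True), axis=[1, 2, 3]).numpy())
# e.g. [24.  0. 24. 24.] -- the exact drop pattern is random

# During inference, the input is returned as is.
print(tf.reduce_sum(drop_path(dummy, training=False), axis=[1, 2, 3]).numpy())
# [12. 12. 12. 12.]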

The most important operation in this paper is the shift operation. In this section, we describe the shift operation and compare it with the original implementation provided by the authors.

A generic feature map is assumed to have the shape [N, H, W, C]. Here we choose a num_div parameter that determines the division size of the channels. The first 4 divisions are shifted (by 1 pixel) in the left, right, up, and down directions. The remaining divisions are kept as is. After the partial shifting, the shifted channels are padded and the overflown pixels are chopped off. This completes the partial shifting operation.

In the original implementation, the code looks roughly like this:

out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down

out[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift

In TensorFlow, it would be infeasible for us to assign shifted channels to a tensor in the middle of the training process. This is why we resort to the following procedure:

  1. Split the channels with the num_div parameter.
  2. Select the first four splits and shift and pad them in the respective directions.
  3. After shifting and padding, concatenate the channels back together.

Figure 4: The TensorFlow-style shift (Manim-rendered animation of the shift operation).

The entire procedure is illustrated in Figure 4.
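
To make the shift concrete, here is a toy illustration (not part of the original example) of the idea described above: shift a single-channel 4x4 map one pixel to the left, chop off the column that overflows, and zero-pad the vacated column.

toy = tf.reshape(tf.range(1, 17, dtype=tf.float32), (1, 4, 4, 1))

# Crop away the left-most column, then place the result back with a zero column on the right.
cropped = tf.image.crop_to_bounding_box(
    toy, offset_height=0, offset_width=1, target_height=4, target_width=3
)
shifted_left = tf.image.pad_to_bounding_box(
    cropped, offset_height=0, offset_width=0, target_height=4, target_width=4
)
print(tf.squeeze(shifted_left).numpy())
# [[ 2.  3.  4.  0.]
#  [ 6.  7.  8.  0.]
#  [10. 11. 12.  0.]
#  [14. 15. 16.  0.]]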

class ShiftViTBlock(layers.Layer):
    """A unit ShiftViT Block.

    Args:
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the MLP features are expanded.
            Defaults to 2.
        mlp_dropout_rate (float): The dropout rate used in the MLP.
        num_div (int): The number of divisions of the feature map's channels.
            In total, 4/num_div of the channels will be shifted. Defaults to 12.
        epsilon (float): Epsilon constant.
        drop_path_prob (float): The drop probability for the drop path.
    """

    def __init__(
        self,
        epsilon,
        drop_path_prob,
        mlp_dropout_rate,
        num_div=12,
        shift_pixel=1,
        mlp_expand_ratio=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.shift_pixel = shift_pixel
        self.mlp_expand_ratio = mlp_expand_ratio
        self.mlp_dropout_rate = mlp_dropout_rate
        self.num_div = num_div
        self.epsilon = epsilon
        self.drop_path_prob = drop_path_prob

    def build(self, input_shape):
        self.H = input_shape[1]
        self.W = input_shape[2]
        self.C = input_shape[3]
        self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
        self.drop_path = (
            DropPath(drop_path_prob=self.drop_path_prob)
            if self.drop_path_prob > 0.0
            else layers.Activation("linear")
        )
        self.mlp = MLP(
            mlp_expand_ratio=self.mlp_expand_ratio,
            mlp_dropout_rate=self.mlp_dropout_rate,
        )

    def get_shift_pad(self, x, mode):
        """Shift the channels according to the chosen mode."""
        if mode == "left":
            offset_height = 0
            offset_width = 0
            target_height = 0
            target_width = self.shift_pixel
        elif mode == "right":
            offset_height = 0
            offset_width = self.shift_pixel
            target_height = 0
            target_width = self.shift_pixel
        elif mode == "up":
            offset_height = 0
            offset_width = 0
            target_height = self.shift_pixel
            target_width = 0
        else:
            offset_height = self.shift_pixel
            offset_width = 0
            target_height = self.shift_pixel
            target_width = 0
        crop = tf.image.crop_to_bounding_box(
            x,
            offset_height=offset_height,
            offset_width=offset_width,
            target_height=self.H - target_height,
            target_width=self.W - target_width,
        )
        shift_pad = tf.image.pad_to_bounding_box(
            crop,
            offset_height=offset_height,
            offset_width=offset_width,
            target_height=self.H,
            target_width=self.W,
        )
        return shift_pad

    def call(self, x, training=False):
        # Split the feature maps
        x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)

        # Shift the feature maps
        x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
        x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
        x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
        x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")

        # Concatenate the shifted and unshifted feature maps
        x = tf.concat(x_splits, axis=-1)

        # Add the residual connection
        shortcut = x
        x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
        return x

The ShiftViT blocks

Figure 5: Shift Blocks in the architecture. Source

Each stage of the architecture has Shift Blocks, as shown in Figure 5. Each of these blocks contains a variable number of stacked ShiftViT blocks (as built in the previous section).

The Shift Blocks are followed by a PatchMerging layer that scales down the feature inputs. The PatchMerging layer contributes to the pyramidal structure of the model.

The PatchMerging layer

This layer merges two adjacent tokens. It helps scale the features down spatially and scale them up channel-wise. We use a Conv2D layer to merge the patches.

class PatchMerging(layers.Layer):
    """The Patch Merging layer.

    Args:
        epsilon (float): The epsilon constant.
    """

    def __init__(self, epsilon, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        filters = 2 * input_shape[-1]
        self.reduction = layers.Conv2D(
            filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
        )
        self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)

    def call(self, x):
        # Apply the patch merging algorithm on the feature maps
        x = self.layer_norm(x)
        x = self.reduction(x)
        return x
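
A quick shape check (illustrative only, not part of the original example) confirms that PatchMerging halves the spatial resolution and doubles the channel count:

# A 12x12 grid of 96-channel tokens becomes a 6x6 grid of 192-channel tokens.
merged = PatchMerging(epsilon=config.epsilon)(tf.zeros((1, 12, 12, 96)))
print(merged.shape)  # (1, 6, 6, 192)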

Stacked Shift Blocks

Each stage has a variable number of stacked ShiftViT blocks, as suggested in the paper. This is a generic layer that contains the stacked shift vit blocks together with the patch merging layer. Combining the two operations (shift ViT block and patch merging) is a design choice we picked for better code reusability.

# Note: This layer will have a different depth of stacking for different stages of the model.
class StackedShiftBlocks(layers.Layer):
    """The layer containing stacked ShiftViTBlocks.

    Args:
        epsilon (float): The epsilon constant.
        mlp_dropout_rate (float): The dropout rate used in the MLP block.
        num_shift_blocks (int): The number of shift vit blocks for this stage.
        stochastic_depth_rate (float): The maximum drop path rate chosen.
        is_merge (boolean): A flag that determines whether to use the Patch Merge
            layer after the shift vit blocks.
        num_div (int): The division of channels of the feature map. Defaults to 12.
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the initial dense layer of
            the MLP is expanded. Defaults to 2.
    """

    def __init__(
        self,
        epsilon,
        mlp_dropout_rate,
        num_shift_blocks,
        stochastic_depth_rate,
        is_merge,
        num_div=12,
        shift_pixel=1,
        mlp_expand_ratio=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.mlp_dropout_rate = mlp_dropout_rate
        self.num_shift_blocks = num_shift_blocks
        self.stochastic_depth_rate = stochastic_depth_rate
        self.is_merge = is_merge
        self.num_div = num_div
        self.shift_pixel = shift_pixel
        self.mlp_expand_ratio = mlp_expand_ratio

    def build(self, input_shapes):
        # Calculate stochastic depth probabilities.
        # Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
        dpr = [
            x
            for x in np.linspace(
                start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
            )
        ]

        # Build the shift blocks as a list of ShiftViT Blocks
        self.shift_blocks = list()
        for num in range(self.num_shift_blocks):
            self.shift_blocks.append(
                ShiftViTBlock(
                    num_div=self.num_div,
                    epsilon=self.epsilon,
                    drop_path_prob=dpr[num],
                    mlp_dropout_rate=self.mlp_dropout_rate,
                    shift_pixel=self.shift_pixel,
                    mlp_expand_ratio=self.mlp_expand_ratio,
                )
            )
        if self.is_merge:
            self.patch_merge = PatchMerging(epsilon=self.epsilon)

    def call(self, x, training=False):
        for shift_block in self.shift_blocks:
            x = shift_block(x, training=training)
        if self.is_merge:
            x = self.patch_merge(x)
        return x

    # Since this is a custom layer, we need to overwrite get_config()
    # so that the model can be easily saved and loaded after training
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "epsilon": self.epsilon,
                "mlp_dropout_rate": self.mlp_dropout_rate,
                "num_shift_blocks": self.num_shift_blocks,
                "stochastic_depth_rate": self.stochastic_depth_rate,
                "is_merge": self.is_merge,
                "num_div": self.num_div,
                "shift_pixel": self.shift_pixel,
                "mlp_expand_ratio": self.mlp_expand_ratio,
            }
        )
        return config
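
As a small illustrative check (not part of the original example), a stage with is_merge=True keeps the token grid size through its ShiftViT blocks and then halves it (while doubling the channels) in the final patch-merging step:

stage = StackedShiftBlocks(
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    num_shift_blocks=2,
    stochastic_depth_rate=config.stochastic_depth_rate,
    is_merge=True,
)
print(stage(tf.zeros((1, 12, 12, 96))).shape)  # (1, 6, 6, 192)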

The ShiftViT model

Build the ShiftViT custom model.

class ShiftViTModel(keras.Model):
    """The ShiftViT Model.

    Args:
        data_augmentation (keras.Model): A data augmentation model.
        projected_dim (int): The dimension to which the patches of the image are projected.
        patch_size (int): The patch size of the images.
        num_shift_blocks_per_stages (list[int]): A list of the number of shift blocks per stage.
        epsilon (float): The epsilon constant.
        mlp_dropout_rate (float): The dropout rate used in the MLP block.
        stochastic_depth_rate (float): The maximum drop rate probability.
        num_div (int): The number of divisions of the channels of the feature map. Defaults to 12.
        shift_pixel (int): The number of pixels to shift. Defaults to 1.
        mlp_expand_ratio (int): The ratio by which the initial MLP dense layer is expanded. Defaults to 2.
    """


    def __init__(
        self,
        data_augmentation,
        projected_dim,
        patch_size,
        num_shift_blocks_per_stages,
        epsilon,
        mlp_dropout_rate,
        stochastic_depth_rate,
        num_div=12,
        shift_pixel=1,
        mlp_expand_ratio=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.data_augmentation = data_augmentation
        self.patch_projection = layers.Conv2D(
            filters=projected_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding="same",
        )
        self.stages = list()
        for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
            if index == len(num_shift_blocks_per_stages) - 1:
                # This is the last stage, do not use the patch merge here.
                is_merge = False
            else:
                is_merge = True
            # Build the stages.
            self.stages.append(
                StackedShiftBlocks(
                    epsilon=epsilon,
                    mlp_dropout_rate=mlp_dropout_rate,
                    num_shift_blocks=num_shift_blocks,
                    stochastic_depth_rate=stochastic_depth_rate,
                    is_merge=is_merge,
                    num_div=num_div,
                    shift_pixel=shift_pixel,
                    mlp_expand_ratio=mlp_expand_ratio,
                )
            )
        self.global_avg_pool = layers.GlobalAveragePooling2D()

        self.classifier = layers.Dense(config.num_classes)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "data_augmentation": self.data_augmentation,
                "patch_projection": self.patch_projection,
                "stages": self.stages,
                "global_avg_pool": self.global_avg_pool,
                "classifier": self.classifier,
            }
        )
        return config

    def _calculate_loss(self, data, training=False):
        (images, labels) = data

        # Augment the images
        augmented_images = self.data_augmentation(images, training=training)

        # Create patches and project the patches.
        projected_patches = self.patch_projection(augmented_images)

        # Pass through the stages
        x = projected_patches
        for stage in self.stages:
            x = stage(x, training=training)

        # Get the logits.
        x = self.global_avg_pool(x)
        logits = self.classifier(x)

        # Calculate the loss and return it.
        total_loss = self.compiled_loss(labels, logits)
        return total_loss, labels, logits

    def train_step(self, inputs):
        with tf.GradientTape() as tape:
            total_loss, labels, logits = self._calculate_loss(
                data=inputs, training=True
            )

        # Apply gradients.
        train_vars = [
            self.data_augmentation.trainable_variables,
            self.patch_projection.trainable_variables,
            self.global_avg_pool.trainable_variables,
            self.classifier.trainable_variables,
        ]
        train_vars = train_vars + [stage.trainable_variables for stage in self.stages]

        # Optimize the gradients.
        grads = tape.gradient(total_loss, train_vars)
        trainable_variable_list = []
        for (grad, var) in zip(grads, train_vars):
            for g, v in zip(grad, var):
                trainable_variable_list.append((g, v))
        self.optimizer.apply_gradients(trainable_variable_list)

        # Update the metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        _, labels, logits = self._calculate_loss(data=data, training=False)

        # Update the metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def call(self, images):
        augmented_images = self.data_augmentation(images)
        x = self.patch_projection(augmented_images)
        for stage in self.stages:
            x = stage(x, training=False)
        x = self.global_avg_pool(x)
        logits = self.classifier(x)
        return logits

Instantiate the model

model = ShiftViTModel(
    data_augmentation=get_augmentation_model(),
    projected_dim=config.projected_dim,
    patch_size=config.patch_size,
    num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    stochastic_depth_rate=config.stochastic_depth_rate,
    num_div=config.num_div,
    shift_pixel=config.shift_pixel,
    mlp_expand_ratio=config.mlp_expand_ratio,
)

Learning rate schedule

In many experiments, we want to warm up the model with a slowly increasing learning rate and then cool down the model with a slowly decaying learning rate. In warmup cosine decay, the learning rate increases linearly during the warmup steps and then decays following a cosine schedule.

# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a warmup cosine decay schedule."""

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate
            lr_max: The maximum learning rate to which the lr should increase during the warmup steps
            warmup_steps: The number of steps for which the model warms up
            total_steps: The total number of steps for model training
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        # Check whether the total number of steps is larger than the warmup
        # steps. If not, then throw a value error.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                + f"larger or equal to warmup steps {self.warmup_steps}."
            )

        # `cos_annealed_lr` is a graph that increases to 1 from the initial step
        # to the warmup step. After that, this graph decays to -1 at the final
        # step mark.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )

        # Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes
        # from 0 to 2. Normalize the graph with 0.5 so that it now goes from 0 to
        # 1. With the normalized graph, we scale it with `lr_max` such that it
        # goes from 0 to `lr_max`.
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        # Check whether warmup_steps is more than 0.
        if self.warmup_steps > 0:
            # Check whether lr_max is larger than lr_start. If not, throw a value
            # error.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    + f"equal to lr_max {self.lr_max}."
                )

            # Calculate the slope with which the learning rate should increase
            # in the warmup schedule. The formula for slope is m = ((b-a)/steps).
            slope = (self.lr_max - self.lr_start) / self.warmup_steps

            # With the formula for a straight line (y = mx+c), build the warmup schedule.
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start

            # When the current step is less than the warmup steps, get the line
            # graph. When the current step is greater than the warmup steps, get
            # the scaled cosine graph.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # When the current step is more than the total steps, return 0, else
        # return the calculated graph.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )

    def get_config(self):
        config = {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
        }
        return config
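
To see what the schedule looks like, here is an optional sketch (not part of the original example) that plots the learning rate over a hypothetical run of 15,000 steps with a 15% warmup; the exact step counts used for training are computed in the next cell.

demo_total_steps = 15_000
demo_warmup_steps = int(0.15 * demo_total_steps)
demo_schedule = WarmUpCosine(
    lr_start=config.lr_start,
    lr_max=config.lr_max,
    warmup_steps=demo_warmup_steps,
    total_steps=demo_total_steps,
)

# The schedule accepts a tensor of steps, so we can evaluate it in one call.
steps = tf.range(demo_total_steps, dtype=tf.float32)
plt.plot(steps, demo_schedule(steps))
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.title("Warmup cosine decay schedule")
plt.show()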

Compile and train the model

# pass sample data to the model so that input shape is available at the time of
# saving the model
sample_ds, _ = next(iter(train_ds))
model(sample_ds, training=False)

# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)

# Calculate the number of steps for warmup.
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)

# Initialize the warmupcosine schedule.
scheduled_lrs = WarmUpCosine(
    lr_start=1e-5,
    lr_max=1e-3,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
)

# Get the optimizer.
optimizer = tfa.optimizers.AdamW(
    learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)

# Compile and pretrain the model.
model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

# Train the model
history = model.fit(
    train_ds,
    epochs=config.epochs,
    validation_data=val_ds,
    callbacks=[
        keras.callbacks.EarlyStopping(
            monitor="val_accuracy",
            patience=5,
            mode="auto",
        )
    ],
)

# Evaluate the model with the test dataset.
print("TESTING")
loss, acc_top1, acc_top5 = model.evaluate(test_ds)
print(f"Loss: {loss:0.2f}")
print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
Epoch 1/100
157/157 [==============================] - 72s 332ms/step - loss: 2.3844 - accuracy: 0.1444 - top-5-accuracy: 0.6051 - val_loss: 2.0984 - val_accuracy: 0.2610 - val_top-5-accuracy: 0.7638
Epoch 2/100
157/157 [==============================] - 49s 314ms/step - loss: 1.9457 - accuracy: 0.2893 - top-5-accuracy: 0.8103 - val_loss: 1.9459 - val_accuracy: 0.3356 - val_top-5-accuracy: 0.8614
Epoch 3/100
157/157 [==============================] - 50s 316ms/step - loss: 1.7093 - accuracy: 0.3810 - top-5-accuracy: 0.8761 - val_loss: 1.5349 - val_accuracy: 0.4585 - val_top-5-accuracy: 0.9045
Epoch 4/100
157/157 [==============================] - 49s 315ms/step - loss: 1.5473 - accuracy: 0.4374 - top-5-accuracy: 0.9090 - val_loss: 1.4257 - val_accuracy: 0.4862 - val_top-5-accuracy: 0.9298
Epoch 5/100
157/157 [==============================] - 50s 316ms/step - loss: 1.4316 - accuracy: 0.4816 - top-5-accuracy: 0.9243 - val_loss: 1.4032 - val_accuracy: 0.5092 - val_top-5-accuracy: 0.9362
Epoch 6/100
157/157 [==============================] - 50s 316ms/step - loss: 1.3588 - accuracy: 0.5131 - top-5-accuracy: 0.9333 - val_loss: 1.2893 - val_accuracy: 0.5411 - val_top-5-accuracy: 0.9457
Epoch 7/100
157/157 [==============================] - 50s 316ms/step - loss: 1.2894 - accuracy: 0.5385 - top-5-accuracy: 0.9410 - val_loss: 1.2922 - val_accuracy: 0.5416 - val_top-5-accuracy: 0.9432
Epoch 8/100
157/157 [==============================] - 49s 315ms/step - loss: 1.2388 - accuracy: 0.5568 - top-5-accuracy: 0.9468 - val_loss: 1.2100 - val_accuracy: 0.5733 - val_top-5-accuracy: 0.9545
Epoch 9/100
157/157 [==============================] - 49s 315ms/step - loss: 1.2043 - accuracy: 0.5698 - top-5-accuracy: 0.9491 - val_loss: 1.2166 - val_accuracy: 0.5675 - val_top-5-accuracy: 0.9520
Epoch 10/100
157/157 [==============================] - 49s 315ms/step - loss: 1.1694 - accuracy: 0.5861 - top-5-accuracy: 0.9528 - val_loss: 1.1738 - val_accuracy: 0.5883 - val_top-5-accuracy: 0.9541
Epoch 11/100
157/157 [==============================] - 50s 316ms/step - loss: 1.1290 - accuracy: 0.5994 - top-5-accuracy: 0.9575 - val_loss: 1.1161 - val_accuracy: 0.6064 - val_top-5-accuracy: 0.9618
Epoch 12/100
157/157 [==============================] - 50s 316ms/step - loss: 1.0861 - accuracy: 0.6157 - top-5-accuracy: 0.9602 - val_loss: 1.1220 - val_accuracy: 0.6133 - val_top-5-accuracy: 0.9576
Epoch 13/100
157/157 [==============================] - 49s 315ms/step - loss: 1.0766 - accuracy: 0.6178 - top-5-accuracy: 0.9612 - val_loss: 1.0108 - val_accuracy: 0.6402 - val_top-5-accuracy: 0.9681
Epoch 14/100
157/157 [==============================] - 49s 315ms/step - loss: 1.0179 - accuracy: 0.6416 - top-5-accuracy: 0.9658 - val_loss: 1.0196 - val_accuracy: 0.6405 - val_top-5-accuracy: 0.9667
Epoch 15/100
157/157 [==============================] - 50s 316ms/step - loss: 1.0028 - accuracy: 0.6470 - top-5-accuracy: 0.9678 - val_loss: 1.0113 - val_accuracy: 0.6415 - val_top-5-accuracy: 0.9672
Epoch 16/100
157/157 [==============================] - 50s 316ms/step - loss: 0.9613 - accuracy: 0.6611 - top-5-accuracy: 0.9710 - val_loss: 1.0516 - val_accuracy: 0.6406 - val_top-5-accuracy: 0.9596
Epoch 17/100
157/157 [==============================] - 50s 316ms/step - loss: 0.9262 - accuracy: 0.6740 - top-5-accuracy: 0.9729 - val_loss: 0.9010 - val_accuracy: 0.6844 - val_top-5-accuracy: 0.9750
Epoch 18/100
157/157 [==============================] - 50s 316ms/step - loss: 0.8768 - accuracy: 0.6916 - top-5-accuracy: 0.9769 - val_loss: 0.8862 - val_accuracy: 0.6908 - val_top-5-accuracy: 0.9767
Epoch 19/100
157/157 [==============================] - 49s 315ms/step - loss: 0.8595 - accuracy: 0.6984 - top-5-accuracy: 0.9768 - val_loss: 0.8732 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9738
Epoch 20/100
157/157 [==============================] - 50s 317ms/step - loss: 0.8252 - accuracy: 0.7103 - top-5-accuracy: 0.9793 - val_loss: 0.9330 - val_accuracy: 0.6745 - val_top-5-accuracy: 0.9718
Epoch 21/100
157/157 [==============================] - 51s 322ms/step - loss: 0.8003 - accuracy: 0.7180 - top-5-accuracy: 0.9814 - val_loss: 0.8912 - val_accuracy: 0.6948 - val_top-5-accuracy: 0.9728
Epoch 22/100
157/157 [==============================] - 51s 326ms/step - loss: 0.7651 - accuracy: 0.7317 - top-5-accuracy: 0.9829 - val_loss: 0.7894 - val_accuracy: 0.7277 - val_top-5-accuracy: 0.9791
Epoch 23/100
157/157 [==============================] - 52s 328ms/step - loss: 0.7372 - accuracy: 0.7415 - top-5-accuracy: 0.9843 - val_loss: 0.7752 - val_accuracy: 0.7284 - val_top-5-accuracy: 0.9804
Epoch 24/100
157/157 [==============================] - 51s 327ms/step - loss: 0.7324 - accuracy: 0.7423 - top-5-accuracy: 0.9852 - val_loss: 0.7949 - val_accuracy: 0.7340 - val_top-5-accuracy: 0.9792
Epoch 25/100
157/157 [==============================] - 51s 323ms/step - loss: 0.7051 - accuracy: 0.7512 - top-5-accuracy: 0.9858 - val_loss: 0.7967 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9787
Epoch 26/100
157/157 [==============================] - 51s 323ms/step - loss: 0.6832 - accuracy: 0.7577 - top-5-accuracy: 0.9870 - val_loss: 0.7840 - val_accuracy: 0.7322 - val_top-5-accuracy: 0.9807
Epoch 27/100
157/157 [==============================] - 51s 322ms/step - loss: 0.6609 - accuracy: 0.7654 - top-5-accuracy: 0.9877 - val_loss: 0.7447 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9816
Epoch 28/100
157/157 [==============================] - 50s 319ms/step - loss: 0.6495 - accuracy: 0.7724 - top-5-accuracy: 0.9883 - val_loss: 0.7885 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9817
Epoch 29/100
157/157 [==============================] - 50s 317ms/step - loss: 0.6491 - accuracy: 0.7707 - top-5-accuracy: 0.9885 - val_loss: 0.7539 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9821
Epoch 30/100
157/157 [==============================] - 50s 317ms/step - loss: 0.6213 - accuracy: 0.7823 - top-5-accuracy: 0.9888 - val_loss: 0.7571 - val_accuracy: 0.7470 - val_top-5-accuracy: 0.9815
Epoch 31/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5976 - accuracy: 0.7902 - top-5-accuracy: 0.9906 - val_loss: 0.7430 - val_accuracy: 0.7508 - val_top-5-accuracy: 0.9817
Epoch 32/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5932 - accuracy: 0.7898 - top-5-accuracy: 0.9910 - val_loss: 0.7545 - val_accuracy: 0.7469 - val_top-5-accuracy: 0.9793
Epoch 33/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5977 - accuracy: 0.7850 - top-5-accuracy: 0.9913 - val_loss: 0.7200 - val_accuracy: 0.7569 - val_top-5-accuracy: 0.9830
Epoch 34/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5552 - accuracy: 0.8041 - top-5-accuracy: 0.9920 - val_loss: 0.7377 - val_accuracy: 0.7552 - val_top-5-accuracy: 0.9818
Epoch 35/100
157/157 [==============================] - 50s 319ms/step - loss: 0.5509 - accuracy: 0.8056 - top-5-accuracy: 0.9921 - val_loss: 0.8125 - val_accuracy: 0.7331 - val_top-5-accuracy: 0.9782
Epoch 36/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5296 - accuracy: 0.8116 - top-5-accuracy: 0.9933 - val_loss: 0.6900 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9849
Epoch 37/100
157/157 [==============================] - 50s 316ms/step - loss: 0.5151 - accuracy: 0.8170 - top-5-accuracy: 0.9941 - val_loss: 0.7275 - val_accuracy: 0.7610 - val_top-5-accuracy: 0.9841
Epoch 38/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5069 - accuracy: 0.8217 - top-5-accuracy: 0.9936 - val_loss: 0.7067 - val_accuracy: 0.7703 - val_top-5-accuracy: 0.9835
Epoch 39/100
157/157 [==============================] - 50s 318ms/step - loss: 0.4771 - accuracy: 0.8304 - top-5-accuracy: 0.9945 - val_loss: 0.7110 - val_accuracy: 0.7668 - val_top-5-accuracy: 0.9836
Epoch 40/100
157/157 [==============================] - 50s 317ms/step - loss: 0.4675 - accuracy: 0.8350 - top-5-accuracy: 0.9956 - val_loss: 0.7130 - val_accuracy: 0.7688 - val_top-5-accuracy: 0.9829
Epoch 41/100
157/157 [==============================] - 50s 319ms/step - loss: 0.4586 - accuracy: 0.8382 - top-5-accuracy: 0.9959 - val_loss: 0.7331 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9806
Epoch 42/100
157/157 [==============================] - 50s 318ms/step - loss: 0.4558 - accuracy: 0.8380 - top-5-accuracy: 0.9959 - val_loss: 0.7187 - val_accuracy: 0.7722 - val_top-5-accuracy: 0.9832
Epoch 43/100
157/157 [==============================] - 50s 320ms/step - loss: 0.4356 - accuracy: 0.8450 - top-5-accuracy: 0.9958 - val_loss: 0.7162 - val_accuracy: 0.7693 - val_top-5-accuracy: 0.9850
Epoch 44/100
157/157 [==============================] - 49s 314ms/step - loss: 0.4425 - accuracy: 0.8433 - top-5-accuracy: 0.9958 - val_loss: 0.7061 - val_accuracy: 0.7698 - val_top-5-accuracy: 0.9853
Epoch 45/100
157/157 [==============================] - 49s 314ms/step - loss: 0.4072 - accuracy: 0.8551 - top-5-accuracy: 0.9967 - val_loss: 0.7025 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9848
Epoch 46/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3865 - accuracy: 0.8644 - top-5-accuracy: 0.9970 - val_loss: 0.7178 - val_accuracy: 0.7740 - val_top-5-accuracy: 0.9844
Epoch 47/100
157/157 [==============================] - 49s 313ms/step - loss: 0.3718 - accuracy: 0.8694 - top-5-accuracy: 0.9973 - val_loss: 0.7216 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9828
Epoch 48/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3733 - accuracy: 0.8673 - top-5-accuracy: 0.9970 - val_loss: 0.7440 - val_accuracy: 0.7713 - val_top-5-accuracy: 0.9841
Epoch 49/100
157/157 [==============================] - 49s 313ms/step - loss: 0.3531 - accuracy: 0.8741 - top-5-accuracy: 0.9979 - val_loss: 0.7220 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9848
Epoch 50/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3502 - accuracy: 0.8738 - top-5-accuracy: 0.9980 - val_loss: 0.7245 - val_accuracy: 0.7734 - val_top-5-accuracy: 0.9836
TESTING
40/40 [==============================] - 2s 56ms/step - loss: 0.7336 - accuracy: 0.7638 - top-5-accuracy: 0.9855
Loss: 0.73
Top 1 test accuracy: 76.38%
Top 5 test accuracy: 98.55%

Save the trained model

Since we created the model by subclassing, we cannot save it in the HDF5 format.

It can only be saved in the TF SavedModel format. In general, this is also the recommended format for saving models.

model.save("ShiftViT")

Model inference

Download sample data for inference

!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
!unzip -q inference_set.zip

Load saved model

# Custom objects are not included when the model is saved.
# At load time, these objects need to be passed in to reconstruct the model
saved_model = tf.keras.models.load_model(
    "ShiftViT",
    custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
)

Utility functions for inference

def process_image(img_path):
    # Read the image file from the string path
    img = tf.io.read_file(img_path)

    # Decode the jpeg into a uint8 tensor
    img = tf.io.decode_jpeg(img, channels=3)

    # Resize the image to match the input size accepted by the model
    # Use `method` as `nearest` to preserve the dtype of the input passed to `resize()`
    img = tf.image.resize(
        img, [config.input_shape[0], config.input_shape[1]], method="nearest"
    )
    return img


def create_tf_dataset(image_dir):
    data_dir = pathlib.Path(image_dir)

    # Create a tf.data dataset from the image directory
    predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)

    # Use map to convert the string paths into uint8 image tensors
    # Setting `num_parallel_calls` helps process multiple images in parallel
    predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)

    # Create a prefetch dataset for better latency and throughput
    predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
    return predict_ds


def predict(predict_ds):
    # The ShiftViT model returns logits (unnormalized predictions)
    logits = saved_model.predict(predict_ds)

    # Normalize the predictions by calling softmax()
    probabilities = tf.nn.softmax(logits)
    return probabilities


def get_predicted_class(probabilities):
    pred_label = np.argmax(probabilities)
    predicted_class = config.label_map[pred_label]
    return predicted_class


def get_confidence_scores(probabilities):
    # Get the indices of the probability scores sorted in descending order
    labels = np.argsort(probabilities)[::-1]
    confidences = {
        config.label_map[label]: np.round((probabilities[label]) * 100, 2)
        for label in labels
    }
    return confidences

Get predictions

img_dir = "inference_set"
predict_ds = create_tf_dataset(img_dir)
probabilities = predict(predict_ds)
print(f"probabilities: {probabilities[0]}")
confidences = get_confidence_scores(probabilities[0])
print(confidences)
1/1 [==============================] - 2s 2s/step
probabilities: [8.7329084e-01 1.3162658e-03 6.1781306e-05 1.9132349e-05 4.4482469e-05
 1.8182898e-06 2.2834571e-05 1.1466043e-05 1.2504059e-01 1.9084632e-04]
{'airplane': 87.33, 'ship': 12.5, 'automobile': 0.13, 'truck': 0.02, 'bird': 0.01, 'deer': 0.0, 'frog': 0.0, 'cat': 0.0, 'horse': 0.0, 'dog': 0.0}

View predictions

plt.figure(figsize=(10, 10))
for images in predict_ds:
    for i in range(min(6, probabilities.shape[0])):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        predicted_class = get_predicted_class(probabilities[i])
        plt.title(predicted_class)
        plt.axis("off")



Conclusion

The most impactful contribution of this paper is not its novel architecture, but the idea that hierarchical ViTs trained with no attention can perform quite well. This raises the question of how essential attention is to the performance of ViTs.

For curious minds, we would suggest reading the ConvNexT paper, which focuses more on training paradigms and architectural details of ViTs rather than on providing a novel attention-based architecture.

Acknowledgements:

Example available on HuggingFace

Trained Model | Demo