
Distilling Vision Transformers

Author: Sayak Paul
Date created: 2022/04/05
Last modified: 2022/04/08
Description: Distillation of Vision Transformers through attention.



Introduction

In the original Vision Transformers (ViT) paper (Dosovitskiy et al.), the authors concluded that ViTs need to be pre-trained on larger datasets to perform on par with Convolutional Neural Networks (CNNs). The larger the dataset, the better. This is mainly due to the lack of inductive biases in the ViT architecture: unlike CNNs, ViTs have no layers that exploit locality. In a follow-up paper (Steiner et al.), the authors show that it is possible to substantially improve the performance of ViTs with stronger regularization and longer training.

Many groups have proposed different ways to deal with the data-intensiveness of ViT training. One such way is shown in the Data-efficient image Transformers (DeiT) paper (Touvron et al.). The authors introduced a distillation technique specific to transformer-based vision models. DeiT was among the first works to show that ViTs can be trained well without larger datasets.

In this example, we implement the distillation recipe proposed in DeiT. This requires us to slightly tweak the original ViT architecture and write a custom training loop to implement the distillation recipe.

To run the example, you need TensorFlow Addons, which you can install with the following command:

pip install tensorflow-addons

To comfortably navigate through this example, you should know how a ViT and knowledge distillation work. Both topics are covered in other Keras examples (image classification with a Vision Transformer, and knowledge distillation) if you need a refresher.


Imports

from typing import List

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers

tfds.disable_progress_bar()
tf.keras.utils.set_random_seed(42)

Constants

# Model
MODEL_TYPE = "deit_distilled_tiny_patch16_224"
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
MLP_UNITS = [
    PROJECTION_DIM * 4,
    PROJECTION_DIM,
]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# Training
NUM_EPOCHS = 20
BASE_LR = 0.0005
WEIGHT_DECAY = 0.0001

# Data
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
NUM_CLASSES = 5

You may have noticed that DROPOUT_RATE is set to 0.0. Dropout is included in the implementation for completeness. For smaller models (like the one used in this example) you don't need it, but it helps bigger models.


Load the tf_flowers dataset and prepare preprocessing utilities

The authors use an array of different augmentation techniques, including MixUp (Zhang et al.), RandAugment (Cubuk et al.), and so on. However, to keep the example simple to follow, we discard them here; a minimal MixUp sketch is included at the end of this section for reference.

def preprocess_dataset(is_training=True):
    def fn(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take random crops.
            image = tf.image.resize(image, (RESOLUTION + 20, RESOLUTION + 20))
            image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
        label = tf.one_hot(label, depth=NUM_CLASSES)
        return image, label

    return fn


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(BATCH_SIZE * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
    return dataset.batch(BATCH_SIZE).prefetch(AUTO)


train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"训练样本的数量: {num_train}")
print(f"验证样本的数量: {num_val}")

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Number of training examples: 3303
Number of validation examples: 367
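
For reference, here is a minimal sketch of how MixUp could be plugged into this pipeline. The mix_up helper below is hypothetical (it is not part of this example's recipe) and assumes batched images with one-hot labels, as produced by prepare_dataset() above:

def mix_up(images, labels, alpha=0.2):
    # Sample per-example mixing coefficients from Beta(alpha, alpha),
    # expressed via two Gamma draws since TF has no Beta sampler.
    batch_size = tf.shape(images)[0]
    gamma_1 = tf.random.gamma((batch_size,), alpha)
    gamma_2 = tf.random.gamma((batch_size,), alpha)
    lam = gamma_1 / (gamma_1 + gamma_2)

    # Mix each example with a shuffled counterpart from the same batch.
    indices = tf.random.shuffle(tf.range(batch_size))
    lam_images = tf.reshape(lam, (-1, 1, 1, 1))
    lam_labels = tf.reshape(lam, (-1, 1))
    images = lam_images * images + (1.0 - lam_images) * tf.gather(images, indices)
    labels = lam_labels * labels + (1.0 - lam_labels) * tf.gather(labels, indices)
    return images, labels

# Hypothetical usage: apply after batching.
# train_dataset = train_dataset.map(mix_up, num_parallel_calls=AUTO)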

Implementing the DeiT variants of ViT

Since DeiT is an extension of ViT, it makes sense to first implement ViT and then extend it to support DeiT's components.

First, we implement a layer for Stochastic Depth (Huang et al.), which is used in DeiT for regularization.

# Referred from: github.com:rwightman/pytorch-image-models.
class StochasticDepth(layers.Layer):
    def __init__(self, drop_prop, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = drop_prop

    def call(self, x, training=True):
        if training:
            keep_prob = 1 - self.drop_prob
            shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
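
As a quick illustrative check (not part of the original), note that the layer drops the whole residual branch for some examples in the batch and rescales the surviving ones by 1 / keep_prob:

# With drop_prob=0.5, each row comes back either all zeros or
# scaled by 1 / keep_prob = 2.0.
sd = StochasticDepth(drop_prop=0.5)
print(sd(tf.ones((4, 3)), training=True).numpy())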

Now, we'll implement the MLP and Transformer blocks.

def mlp(x, dropout_rate: float, hidden_units: List):
    """Transformer 块的前馈神经网络."""
    # 遍历隐藏层单元并
    # 添加 Dense => Dropout.
    for (idx, units) in enumerate(hidden_units):
        x = layers.Dense(
            units,
            activation=tf.nn.gelu if idx == 0 else None,
        )(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


def transformer(drop_prob: float, name: str) -> keras.Model:
    """带预规范化的 Transformer 块."""
    num_patches = NUM_PATCHES + 2 if "distilled" in MODEL_TYPE else NUM_PATCHES + 1
    encoded_patches = layers.Input((num_patches, PROJECTION_DIM))

    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS,
        key_dim=PROJECTION_DIM,
        dropout=DROPOUT_RATE,
    )(x1, x1)
    attention_output = (
        StochasticDepth(drop_prob)(attention_output) if drop_prob else attention_output
    )

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=DROPOUT_RATE)
    x4 = StochasticDepth(drop_prob)(x4) if drop_prob else x4

    # Skip connection 2.
    outputs = layers.Add()([x2, x4])

    return keras.Model(encoded_patches, outputs, name=name)
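
Since MODEL_TYPE contains "distilled", each block operates on NUM_PATCHES + 2 = 198 tokens. A quick illustrative shape check:

# The block should map (batch, 198, PROJECTION_DIM) to the same shape.
test_block = transformer(drop_prob=0.1, name="test_block")
dummy_tokens = tf.ones((2, NUM_PATCHES + 2, PROJECTION_DIM))
print(test_block(dummy_tokens).shape)  # (2, 198, 192)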

We'll now implement a ViTClassifier class based on the components we have just developed. Here, we follow the original pooling strategy used in the ViT paper: use a class token, and use the feature representation corresponding to it for classification.

class ViTClassifier(keras.Model):
    """视觉 Transformer 基类。"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Patchify + linear projection + reshaping.
        self.projection = keras.Sequential(
            [
                layers.Conv2D(
                    filters=PROJECTION_DIM,
                    kernel_size=(PATCH_SIZE, PATCH_SIZE),
                    strides=(PATCH_SIZE, PATCH_SIZE),
                    padding="VALID",
                    name="conv_projection",
                ),
                layers.Reshape(
                    target_shape=(NUM_PATCHES, PROJECTION_DIM),
                    name="flatten_projection",
                ),
            ],
            name="projection",
        )

        # Positional embedding.
        init_shape = (
            1,
            NUM_PATCHES + 1,
            PROJECTION_DIM,
        )
        self.positional_embedding = tf.Variable(
            tf.zeros(init_shape), name="pos_embedding"
        )

        # Transformer blocks.
        dpr = [x for x in tf.linspace(0.0, DROP_PATH_RATE, NUM_LAYERS)]
        self.transformer_blocks = [
            transformer(drop_prob=dpr[i], name=f"transformer_block_{i}")
            for i in range(NUM_LAYERS)
        ]

        # CLS token.
        initial_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.cls_token = tf.Variable(
            initial_value=initial_value, trainable=True, name="cls"
        )

        # Other layers.
        self.dropout = layers.Dropout(DROPOUT_RATE)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append class token if needed.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        projected_patches = tf.concat([cls_token, projected_patches], axis=1)

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Pool representation.
        encoded_patches = representation[:, 0]

        # Classification head.
        output = self.head(encoded_patches)
        return output

This class can be used standalone as ViT and is end-to-end trainable. Just remove the distilled phrase in MODEL_TYPE and it would work with vit_tiny = ViTClassifier(). Let's now extend it to DeiT, whose schematic is presented in the DeiT paper.

Apart from the class token, DeiT has another token for distillation. During distillation, the logits corresponding to the class token are compared with the true labels, and the logits corresponding to the distillation token are compared with the teacher's predictions.

class ViTDistilled(ViTClassifier):
    def __init__(self, regular_training=False, **kwargs):
        super().__init__(**kwargs)
        self.num_tokens = 2
        self.regular_training = regular_training

        # CLS and distillation tokens, positional embedding.
        init_value = tf.zeros((1, 1, PROJECTION_DIM))
        self.dist_token = tf.Variable(init_value, name="dist_token")
        self.positional_embedding = tf.Variable(
            tf.zeros(
                (
                    1,
                    NUM_PATCHES + self.num_tokens,
                    PROJECTION_DIM,
                )
            ),
            name="pos_embedding",
        )

        # Head layers.
        self.head = layers.Dense(
            NUM_CLASSES,
            name="classification_head",
        )
        self.head_dist = layers.Dense(
            NUM_CLASSES,
            name="distillation_head",
        )

    def call(self, inputs, training=True):
        n = tf.shape(inputs)[0]

        # Create patches and project the patches.
        projected_patches = self.projection(inputs)

        # Append the tokens.
        cls_token = tf.tile(self.cls_token, (n, 1, 1))
        dist_token = tf.tile(self.dist_token, (n, 1, 1))
        cls_token = tf.cast(cls_token, projected_patches.dtype)
        dist_token = tf.cast(dist_token, projected_patches.dtype)
        projected_patches = tf.concat(
            [cls_token, dist_token, projected_patches], axis=1
        )

        # Add positional embeddings to the projected patches.
        encoded_patches = (
            self.positional_embedding + projected_patches
        )  # (B, number_patches, projection_dim)
        encoded_patches = self.dropout(encoded_patches)

        # Iterate over the number of layers and stack up blocks of Transformer.
        for transformer_module in self.transformer_blocks:
            # Add a Transformer block.
            encoded_patches = transformer_module(encoded_patches)

        # Final layer normalization.
        representation = self.layer_norm(encoded_patches)

        # Classification heads.
        x, x_dist = (
            self.head(representation[:, 0]),
            self.head_dist(representation[:, 1]),
        )

        if not training or self.regular_training:
            # During standard training / fine-tuning and at inference,
            # average the classifier predictions.
            return (x + x_dist) / 2

        elif training:
            # Only return separate classification predictions when training in distilled mode.
            return x, x_dist

Let's verify if the ViTDistilled class can be initialized and called as expected.

deit_tiny_distilled = ViTDistilled()

dummy_inputs = tf.ones((2, 224, 224, 3))
outputs = deit_tiny_distilled(dummy_inputs, training=False)
print(outputs.shape)
(2, 5)
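
When called with training=True (and regular_training left at its default of False), the model instead returns the two sets of logits separately:

# In distillation mode the forward pass yields a tuple of
# (class-token logits, distillation-token logits).
x, x_dist = deit_tiny_distilled(dummy_inputs, training=True)
print(x.shape, x_dist.shape)  # (2, 5) (2, 5)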

Implementing the trainer

Unlike standard knowledge distillation (Hinton et al.), where a temperature-scaled softmax and KL divergence are used, the DeiT authors use the following "hard" distillation loss:

L = 1/2 * CE(psi(Z_s), y) + 1/2 * CE(psi(Z_s), y_t)

Here,

  • CE is cross-entropy
  • psi is the softmax function
  • Z_s denotes the student predictions (logits)
  • y denotes the ground-truth labels
  • y_t denotes the teacher predictions (the argmax of the teacher's logits)

In the implementation below, the first term becomes student_loss (computed from the class-token logits) and the second becomes distillation_loss (computed from the distillation-token logits).
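
As a small illustrative computation of this loss on dummy logits (the DeiT trainer below performs the same computation inside train_step):

# Illustrative only: the hard-distillation loss on random logits.
ce_label = keras.losses.CategoricalCrossentropy(from_logits=True)
ce_teacher = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

cls_logits = tf.random.normal((4, NUM_CLASSES))   # Z_s via the class token
dist_logits = tf.random.normal((4, NUM_CLASSES))  # Z_s via the distillation token
y = tf.one_hot([0, 1, 2, 3], depth=NUM_CLASSES)   # ground-truth labels
y_t = tf.constant([0, 1, 2, 3])                   # teacher's argmax predictions

loss = (ce_label(y, cls_logits) + ce_teacher(y_t, dist_logits)) / 2
print(float(loss))
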
class DeiT(keras.Model):
    # Reference:
    # https://keras.io/examples/vision/knowledge_distillation/
    def __init__(self, student, teacher, **kwargs):
        super().__init__(**kwargs)
        self.student = student
        self.teacher = teacher

        self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
        self.dist_loss_tracker = keras.metrics.Mean(name="distillation_loss")

    @property
    def metrics(self):
        metrics = super().metrics
        metrics.append(self.student_loss_tracker)
        metrics.append(self.dist_loss_tracker)
        return metrics

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn

    def train_step(self, data):
        # Unpack data.
        x, y = data

        # Forward pass of the teacher.
        teacher_predictions = tf.nn.softmax(self.teacher(x, training=False), -1)
        teacher_predictions = tf.argmax(teacher_predictions, -1)

        with tf.GradientTape() as tape:
            # Forward pass of the student.
            cls_predictions, dist_predictions = self.student(x / 255.0, training=True)

            # Compute losses.
            student_loss = self.student_loss_fn(y, cls_predictions)
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, dist_predictions
            )
            loss = (student_loss + distillation_loss) / 2

        # Compute gradients.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights.
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        student_predictions = (cls_predictions + dist_predictions) / 2
        self.compiled_metrics.update_state(y, student_predictions)
        self.dist_loss_tracker.update_state(distillation_loss)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def test_step(self, data):
        # Unpack data.
        x, y = data

        # Compute predictions.
        y_prediction = self.student(x / 255.0, training=False)

        # Calculate the loss.
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        self.student_loss_tracker.update_state(student_loss)

        # Return a dict of performance.
        results = {m.name: m.result() for m in self.metrics}
        return results

    def call(self, inputs):
        return self.student(inputs / 255.0, training=False)

Load the teacher model

This model is based on the BiT family of ResNets (Kolesnikov et al.), fine-tuned on the tf_flowers dataset. You can refer to this notebook to know how the training was performed. The teacher model has about 212 million parameters, which is about 40x more than the student.

!wget -q https://github.com/sayakpaul/deit-tf/releases/download/v0.1.0/bit_teacher_flowers.zip
!unzip -q bit_teacher_flowers.zip
bit_teacher_flowers = keras.models.load_model("bit_teacher_flowers")
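
To verify the size gap between the two models, a quick check (not in the original):

# The teacher should come in at roughly 212M parameters and the student
# (deit_tiny_distilled, built earlier) at roughly 5-6M (about 40x smaller).
print(f"Teacher parameters: {bit_teacher_flowers.count_params() / 1e6:.2f}M")
print(f"Student parameters: {deit_tiny_distilled.count_params() / 1e6:.2f}M")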

Training through distillation

deit_tiny = ViTDistilled()
deit_distiller = DeiT(student=deit_tiny, teacher=bit_teacher_flowers)

lr_scaled = (BASE_LR / 512) * BATCH_SIZE
deit_distiller.compile(
    optimizer=tfa.optimizers.AdamW(weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled),
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(
        from_logits=True, label_smoothing=0.1
    ),
    distillation_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
_ = deit_distiller.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)
Epoch 1/20
13/13 [==============================] - 44s 2s/step - accuracy: 0.2343 - student_loss: 2.2630 - distillation_loss: 1.7818 - val_accuracy: 0.2234 - val_student_loss: 1.6622 - val_distillation_loss: 0.0000e+00
Epoch 2/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2150 - student_loss: 1.6377 - distillation_loss: 1.6138 - val_accuracy: 0.1907 - val_student_loss: 1.6150 - val_distillation_loss: 0.0000e+00
Epoch 3/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2552 - student_loss: 1.6073 - distillation_loss: 1.5970 - val_accuracy: 0.1907 - val_student_loss: 1.6093 - val_distillation_loss: 0.0000e+00
Epoch 4/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2564 - student_loss: 1.5954 - distillation_loss: 1.5902 - val_accuracy: 0.2997 - val_student_loss: 1.5958 - val_distillation_loss: 0.0000e+00
Epoch 5/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.2922 - student_loss: 1.5839 - distillation_loss: 1.5704 - val_accuracy: 0.3488 - val_student_loss: 1.5635 - val_distillation_loss: 0.0000e+00
Epoch 6/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.3815 - student_loss: 1.4865 - distillation_loss: 1.4551 - val_accuracy: 0.3815 - val_student_loss: 1.4975 - val_distillation_loss: 0.0000e+00
Epoch 7/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4151 - student_loss: 1.4027 - distillation_loss: 1.3441 - val_accuracy: 0.3733 - val_student_loss: 1.4083 - val_distillation_loss: 0.0000e+00
Epoch 8/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4423 - student_loss: 1.3616 - distillation_loss: 1.2877 - val_accuracy: 0.4005 - val_student_loss: 1.4014 - val_distillation_loss: 0.0000e+00
Epoch 9/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4475 - student_loss: 1.3095 - distillation_loss: 1.2200 - val_accuracy: 0.4496 - val_student_loss: 1.3211 - val_distillation_loss: 0.0000e+00
Epoch 10/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.4959 - student_loss: 1.2638 - distillation_loss: 1.1508 - val_accuracy: 0.4932 - val_student_loss: 1.2839 - val_distillation_loss: 0.0000e+00
Epoch 11/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5431 - student_loss: 1.2063 - distillation_loss: 1.0948 - val_accuracy: 0.5559 - val_student_loss: 1.1938 - val_distillation_loss: 0.0000e+00
Epoch 12/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5771 - student_loss: 1.1742 - distillation_loss: 1.0461 - val_accuracy: 0.5695 - val_student_loss: 1.1362 - val_distillation_loss: 0.0000e+00
Epoch 13/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5601 - student_loss: 1.1724 - distillation_loss: 1.0457 - val_accuracy: 0.5477 - val_student_loss: 1.1929 - val_distillation_loss: 0.0000e+00
Epoch 14/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.5777 - student_loss: 1.1717 - distillation_loss: 1.0378 - val_accuracy: 0.5777 - val_student_loss: 1.1171 - val_distillation_loss: 0.0000e+00
Epoch 15/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6173 - student_loss: 1.1232 - distillation_loss: 0.9782 - val_accuracy: 0.5640 - val_student_loss: 1.1229 - val_distillation_loss: 0.0000e+00
Epoch 16/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6237 - student_loss: 1.1091 - distillation_loss: 0.9627 - val_accuracy: 0.5886 - val_student_loss: 1.1371 - val_distillation_loss: 0.0000e+00
Epoch 17/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6261 - student_loss: 1.0880 - distillation_loss: 0.9341 - val_accuracy: 0.6322 - val_student_loss: 1.0972 - val_distillation_loss: 0.0000e+00
Epoch 18/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6427 - student_loss: 1.0688 - distillation_loss: 0.9117 - val_accuracy: 0.6431 - val_student_loss: 1.0548 - val_distillation_loss: 0.0000e+00
Epoch 19/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6458 - student_loss: 1.0529 - distillation_loss: 0.8903 - val_accuracy: 0.6076 - val_student_loss: 1.0761 - val_distillation_loss: 0.0000e+00
Epoch 20/20
13/13 [==============================] - 16s 1s/step - accuracy: 0.6382 - student_loss: 1.0641 - distillation_loss: 0.9049 - val_accuracy: 0.6240 - val_student_loss: 1.0521 - val_distillation_loss: 0.0000e+00
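
Since DeiT.call() (defined above) rescales the inputs and averages the two heads, you can run inference directly through the distiller. A usage sketch:

# Inference sketch: predict on one validation batch.
images, labels = next(iter(val_dataset))
logits = deit_distiller.predict(images)
print(tf.argmax(logits, axis=-1)[:5])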

If we had trained the same model (the ViTClassifier) from scratch with the exact same hyperparameters, the model would have scored about 59% accuracy. You can adapt the following code to reproduce this result:

vit_tiny = ViTClassifier()

inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
x = keras.layers.Rescaling(scale=1./255)(inputs)
outputs = vit_tiny(x)
model = keras.Model(inputs, outputs)

model.compile(...)
model.fit(...)
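
One possible way to fill in the placeholders, reusing the hyperparameters from the distillation run above (a sketch, not the exact script behind the reported 59% result; remember to drop distilled from MODEL_TYPE before building the ViTClassifier):

# A hedged completion of the placeholders above, mirroring the
# hyperparameters of the distillation run.
model.compile(
    optimizer=tfa.optimizers.AdamW(
        weight_decay=WEIGHT_DECAY, learning_rate=lr_scaled
    ),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=["accuracy"],
)
model.fit(train_dataset, validation_data=val_dataset, epochs=NUM_EPOCHS)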

Notes

  • With the use of distillation, we are effectively transferring the inductive biases of a CNN-based teacher model.
  • Interestingly, as shown in the paper, this distillation recipe works worse when the teacher is a Transformer rather than a CNN.
  • The regularization used to train DeiT models is important.
  • ViT models are initialized with an array of different initializers, including truncated normal, random normal, Glorot uniform, and so on. If you are looking for an end-to-end reproduction of the original results, don't forget to initialize the ViTs well.
  • If you want to explore the pre-trained DeiT models in TensorFlow and Keras along with their fine-tuning code, check out these models on TF-Hub. A loading sketch follows this list.
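
A loading sketch with tensorflow_hub (imported at the top); the model handle below is an assumption made for illustration, so check the TF-Hub page for the exact published names:

# Hypothetical handle; look up the exact name on TF-Hub before using.
model_url = "https://tfhub.dev/sayakpaul/deit_tiny_distilled_patch16_224/1"
deit_hub_layer = hub.KerasLayer(model_url)

inputs = keras.Input((RESOLUTION, RESOLUTION, 3))
hub_outputs = deit_hub_layer(inputs)
hub_model = keras.Model(inputs, hub_outputs)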

Acknowledgements

  • Ross Wightman for keeping timm updated with readable implementations. I referred to the implementations of ViT and DeiT while implementing them in TensorFlow.
  • Aritra Roy Gosthipaty, who implemented some portions of the ViTClassifier in another project.
  • The Google Developers Experts program for supporting me with GCP credits, which were used to run experiments for this example.

Example available on HuggingFace:

Trained model | Demo