代码示例 / 计算机视觉 / 当递归遇上变压器

当递归遇上变压器

作者: Aritra Roy Gosthipaty, Suvaditya Mukherjee
创建日期: 2023/03/12
最后修改日期: 2023/03/12
描述: 使用时间潜在瓶颈网络进行图像分类。

在 Colab 中查看 GitHub 源码


介绍

一个简单的递归神经网络(RNN)表现出强烈的归纳偏差,倾向于学习 时间压缩表示公式 1 显示了递归公式, 其中 h_t 是整个输入序列 x 的压缩表示(一个单一向量)。

RNN 的方程
公式 1: 递归方程。 (来源: Aritra 和 Suvaditya)

另一方面,变压器(Vaswani 等人)对学习时间压缩表示几乎没有归纳偏差。 变压器在自然语言处理(NLP)和视觉任务中通过其成对注意机制取得了最先进的成果。

虽然变压器有能力关注输入序列的不同部分,但注意力计算的性质是二次的。

Didolkar 等人认为,拥有更压缩的 序列表示可能对泛化是有益的,因为它可以更轻松地重用重新利用,并且具有更少的不相关细节。虽然压缩是好的, 但他们也注意到过多的压缩会损害表达能力。

作者提出了一种解决方案,将计算分为两个流。一个慢流本质上是递归的,而一个快流被参数化为 变压器。虽然这种方法的新颖之处在于引入不同的处理流以保留和处理潜在状态,但它与其他 作品(如Perceiver Mechanism(Jaegle 等人)Grounded Language Learning Fast and Slow(Hill 等人))有相似之处。

下面的示例探索如何利用新的时间潜在瓶颈 机制在 CIFAR-10 数据集上进行图像分类。我们通过自定义 RNNCell 实现 该模型,以实现高性能向量化设计。

注意: 这个示例使用 TensorFlow 2.12.0,必须安装在我们的 系统中。


设置导入

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW

import random
from matplotlib import pyplot as plt

# 设置种子以确保可重复性。
keras.utils.set_random_seed(42)

AUTO = tf.data.AUTOTUNE

设置所需配置

我们设置一些在我们设计的管道中需要的配置参数。当前的参数用于 CIFAR10 数据集

该模型还支持 mixed-precision 设置,这会将模型量化为在可用时使用 16位浮点数,同时将某些参数保留为32位以确保数值稳定性。这带来了性能上的好处,因为模型的占用空间显著减少,同时推理时速度提升。

config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 2048,
    "buffer_size": 2048 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 30,
    "patch_size": 4,
    "embed_dim": 64,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 4,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}

if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
INFO:tensorflow:混合精度兼容性检查(mixed_float16):OK
您的 GPU 在 dtype 策略 mixed_float16 下运行时可能会很快,因为它的计算能力至少为 7.0。您的 GPU: NVIDIA A100-PCIE-40GB,计算能力 8.0

加载 CIFAR-10 数据集

我们将使用 CIFAR10 数据集进行实验。这个数据集 包含 50,000 张图像的训练集,分为 10 个类,标准图像大小为 (32, 32, 3)。 它还有一组独立的 10,000 张具有相似特征的图像。有关数据集的更多信息,可以在数据集的官方网站以及 keras.datasets.cifar10 API 参考中找到。

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)

定义训练和验证/测试管道的数据增强

我们为对数据进行图像增强定义了独立的管道。这一步很重要,使模型对变化更具鲁棒性,帮助其更好地泛化。 我们执行的预处理和增强步骤如下:

  • 重新缩放(训练,测试):这一步骤旨在将所有图像像素值从 [0,255] 范围标准化到 [0,1)。这有助于在训练过程中保持数值稳定性。
  • 调整大小(训练,测试):我们将图像从其原始大小 (32, 32) 调整为 (52, 52)。这是为了适应随机裁剪,并符合论文中给出的数据规格。
  • 随机裁剪(训练):该层随机选择图像的一个大小为 (48, 48) 的裁剪/子区域。
  • 随机翻转(训练):该层随机水平翻转所有图像,保持图像大小不变。
# 构建 `train` 增强管道。
train_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# 构建 `val` 和 `test` 数据管道。
test_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)

# 我们定义函数而不是简单的 lambda 函数,以通过 [`keras.Sequential`](/api/models/sequential#sequential-class) 运行,以解决此警告:
# (https://github.com/tensorflow/tensorflow/issues/56089)

def train_map_fn(image, label):
    return train_augmentation(image), label

def test_map_fn(image, label):
    return test_augmentation(image), label

将数据集加载到 tf.data.Dataset 对象中

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(train_map_fn, num_parallel_calls=AUTO)
    .shuffle(config["buffer_size"])
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

时间潜在瓶颈

摘自论文:

在大脑中,短期记忆和长期记忆以一种专业的方式发展。短期记忆能够快速变化,以反应即时的感官输入和感知。相反,长期记忆变化缓慢,具有高度选择性,并涉及重复的巩固。

受到短期和长期记忆的启发,作者介绍了快速流和慢速流计算。快速流具有高容量的短期记忆,可以快速响应感官输入(变换器)。慢速流具有长期记忆,以较慢速率更新并总结最相关的信息(递归)。

为了实现这个想法,我们需要:

  • 采集一系列数据。
  • 将序列划分为固定大小的块。
  • 快速流在每个块内操作。它提供细粒度的局部信息。
  • 慢速流整合并聚合跨块的信息。它提供粗粒度的远程信息。 快速流和慢速流引入了所谓的信息不对称。两个流通过注意力的瓶颈相互作用。图1显示了模型的架构。
模型架构
图1:模型的架构。(来源:https://arxiv.org/abs/2205.14794)

作者还提出了一个PyTorch风格的伪代码,如算法1所示。

模型的伪代码
算法1:PyTorch风格的伪代码。(来源:https://arxiv.org/abs/2205.14794)

PatchEmbedding

这个自定义的 keras.layers.Layer 用于从图像生成补丁,并使用 keras.layers.Embedding 将其转换为更高维的嵌入空间。补丁操作是通过 keras.layers.Conv2D 实例完成的,而不是传统的 tf.image.extract_patches,以便实现向量化。

一旦图像补丁完成,我们将其重塑,以便获得展平的表示,其中维度的数量是嵌入维度。在这个阶段,我们还向令牌中注入位置信息。

在获得令牌后,我们将其分块。分块操作涉及从嵌入输出中获取固定大小的序列以创建“块”,这些块随后将作为模型的最终输入。

class PatchEmbedding(layers.Layer):
    """图像到补丁嵌入。
    Args:
        image_size (`Tuple[int]`): 输入图像的大小。
        patch_size (`Tuple[int]`): 补丁的大小。
        embed_dim (`int`): 嵌入的维度。
        chunk_size (`int`): 要分块的补丁数量。
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        chunk_size,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # 计算补丁分辨率。
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]

        # 存储参数。
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]

        # 定义补丁的位置。
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

        # 创建层。
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            name="projection",
        )
        self.flatten = layers.Reshape(
            target_shape=(-1, embed_dim),
            name="flatten",
        )
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=embed_dim,
            name="position_embedding",
        )
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
            name="chunking_layer",
        )

    def call(self, inputs):
        # 将输入投影到嵌入维度。
        x = self.projection(inputs)

        # 展平补丁并添加位置嵌入。
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)

        # 规范化嵌入。
        x = self.layernorm(x)

        # 分块令牌。
        x = self.chunking_layer(x)

        return x

FeedForwardNetwork

这个自定义的 keras.layers.Layer 实例允许我们定义一个通用的FFN以及一个dropout。

class FeedForwardNetwork(layers.Layer):
    """前馈神经网络。
    Args:
        dims (`int`): FFN中的单元数量。
        dropout (`float`): FFN的dropout概率。
    """

    def __init__(self, dims, dropout, **kwargs):
        super().__init__(**kwargs)

        # 创建层。
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ],
            name="ffn",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )

    def call(self, inputs):
        # 应用FFN。
        x = self.layernorm(inputs)
        x = inputs + self.ffn(x)
        return x

BaseAttention

这个自定义的 keras.layers.Layer 实例是一个 super/base 类,它封装了一个 keras.layers.MultiHeadAttention 层以及一些其他组件。这为我们模型中的所有注意力层/模块提供了基本的共同功能。

class BaseAttention(layers.Layer):
    """基础注意力模块。
    Args:
        num_heads (`int`): 注意力头的数量。
        key_dim (`int`): 每个注意力头的键的大小。
        dropout (`float`): 注意力模块的丢弃概率。
    """

    def __init__(self, num_heads, key_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
            name="mha",
        )
        self.query_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="q_layernorm",
        )
        self.key_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="k_layernorm",
        )
        self.value_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="v_layernorm",
        )

        self.attention_scores = None

    def call(self, input_query, key, value):
        # 应用注意力模块。
        query = self.query_layernorm(input_query)
        key = self.key_layernorm(key)
        value = self.value_layernorm(value)
        (attention_outputs, attention_scores) = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )

        # 保存注意力得分以供后续可视化。
        self.attention_scores = attention_scores

        # 将输入加到注意力输出中。
        x = input_query + attention_outputs
        return x

带有前馈网络层的Attention

这个自定义的keras.layers.Layer实现结合了BaseAttentionFeedForwardNetwork组件,开发出一个将在模型中重复使用的块。该模块具有高度的可定制性和灵活性,可以对内部层进行更改。

class AttentionWithFFN(layers.Layer):
    """带有前馈网络的注意力。
    Args:
        ffn_dims (`int`): 前馈网络中的单元数量。
        ffn_dropout (`float`): 前馈网络的丢弃概率。
        num_heads (`int`): 注意力头的数量。
        key_dim (`int`): 每个注意力头的键的大小。
        attn_dropout (`float`): 注意力模块的丢弃概率。
    """

    def __init__(
        self,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # 创建层。
        self.attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
            name="base_attn",
        )
        self.ffn = FeedForwardNetwork(
            dims=ffn_dims,
            dropout=ffn_dropout,
            name="ffn",
        )

        self.attention_scores = None

    def call(self, query, key, value):
        # 应用注意力模块。
        x = self.attention(query, key, value)

        # 保存注意力分数以便后续可视化。
        self.attention_scores = self.attention.attention_scores

        # 应用前馈网络。
        x = self.ffn(x)
        return x

自定义RNN单元用于时序潜在瓶颈感知模块

算法1(伪代码)通过for循环描绘了递归。循环确实使得实现更简单,但会影响训练时间。在本节中,我们将自定义递归逻辑封装在CustomRecurrentCell中。这个自定义单元将被封装在Keras RNN API中,使整个代码可以向量化。

这个作为keras.layers.Layer实现的自定义单元是模型逻辑的重要组成部分。单元的功能可以分为两个部分: - 慢流(时序潜在瓶颈):

  • 该模块由单个AttentionWithFFN层组成,该层解析前一个慢流的输出,一个中间隐藏表示(即时序潜在瓶颈中的潜在表示)作为查询,最新的快速流的输出作为键和值。该层也可以被视为交叉注意力层。
  • 快流(感知模块):
  • 该模块由交错的AttentionWithFFN层组成。该流由n层的SelfAttentionCrossAttention以顺序方式组成。
  • 在这里,一些层将分块输入作为查询、键和值(也称为自注意力层)。
  • 其他层将取自时序潜在瓶颈模块的中间状态输出作为查询,同时使用之前的自注意力层的输出作为键和值。
class CustomRecurrentCell(layers.Layer):
    """自定义递归单元。
    Args:
        chunk_size (`int`): 每块中的标记数量。
        r (`int`): 每 **r** 个自注意力一个交叉注意力。
        num_layers (`int`): 层数。
        ffn_dims (`int`): FFN 中的单元数。
        ffn_dropout (`float`): FFN 的丢弃概率。
        num_heads (`int`): 注意力头的数量。
        key_dim (`int`): 每个注意力头的键的大小。
        attn_dropout (`float`): 注意力模块的丢弃概率。
    """

    def __init__(
        self,
        chunk_size,
        r,
        num_layers,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # 保存参数。
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_droput = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # 创建状态大小和输出大小。这对
        # 自定义递归逻辑很重要。
        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        self.get_attention_scores = False
        self.attention_scores = []

        # 知觉模块
        perceptual_module = list()
        for layer_idx in range(num_layers):
            perceptual_module.append(
                AttentionWithFFN(
                    ffn_dims=ffn_dims,
                    ffn_dropout=ffn_dropout,
                    num_heads=num_heads,
                    key_dim=key_dim,
                    attn_dropout=attn_dropout,
                    name=f"pm_self_attn_{layer_idx}",
                )
            )
            if layer_idx % r == 0:
                perceptual_module.append(
                    AttentionWithFFN(
                        ffn_dims=ffn_dims,
                        ffn_dropout=ffn_dropout,
                        num_heads=num_heads,
                        key_dim=key_dim,
                        attn_dropout=attn_dropout,
                        name=f"pm_cross_attn_ffn_{layer_idx}",
                    )
                )
        self.perceptual_module = perceptual_module

        # 时间潜在瓶颈模块
        self.tlb_module = AttentionWithFFN(
            ffn_dims=ffn_dims,
            ffn_dropout=ffn_dropout,
            num_heads=num_heads,
            key_dim=key_dim,
            attn_dropout=attn_dropout,
            name=f"tlb_cross_attn_ffn",
        )

    def call(self, inputs, states):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]
        slow_stream = states[0]
        fast_stream = inputs

        for layer_idx, layer in enumerate(self.perceptual_module):
            fast_stream = layer(query=fast_stream, key=fast_stream, value=fast_stream)

            if layer_idx % self.r == 0:
                fast_stream = layer(
                    query=fast_stream, key=slow_stream, value=slow_stream
                )

        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        # 保存注意力得分以便后续可视化。
        if self.get_attention_scores:
            self.attention_scores.append(self.tlb_module.attention_scores)

        return fast_stream, [slow_stream]

TemporalLatentBottleneckModel 封装完整模型

在这里,我们只是将完整模型封装,以便用于训练。

class TemporalLatentBottleneckModel(keras.Model):
    """模型训练器。
    参数:
        patch_layer ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): 切片层。
        custom_cell ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): 自定义递归单元。
    """

    def __init__(self, patch_layer, custom_cell, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell, name="rnn")
        self.gap = layers.GlobalAveragePooling1D(name="gap")
        self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs

构建模型

为了开始训练,我们现在单独定义组件并将它们作为参数传递给我们的包装类,以准备最终的训练模型。我们定义一个 PatchEmbed 层,以及基于 CustomCell 的 RNN。

# 构建模型。
patch_layer = PatchEmbedding(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"],
    chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
    chunk_size=config["chunk_size"],
    r=config["r"],
    num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"],
    ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"],
    key_dim=config["embed_dim"],
    attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
    patch_layer=patch_layer,
    custom_cell=custom_rnn_cell,
)

指标和回调

我们使用 AdamW 优化器,因为它在多个基准任务中已被证明能非常好地执行。从优化的角度来看,它是 keras.optimizers.Adam 优化器的一个版本,同时具有权重衰减。

对于损失函数,我们使用 keras.losses.SparseCategoricalCrossentropy 函数,它利用预测和实际逻辑之间的简单交叉熵。我们还计算数据的准确性作为合理性检查。

optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

使用 model.fit() 训练模型

我们传递训练数据集并运行训练。

history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
)
Epoch 1/30  
20/20 [==============================] - 104s 3s/step - loss: 2.6284 - accuracy: 0.1010 - val_loss: 2.2835 - val_accuracy: 0.1251  
Epoch 2/30  
20/20 [==============================] - 35s 2s/step - loss: 2.2797 - accuracy: 0.1542 - val_loss: 2.1721 - val_accuracy: 0.1846  
Epoch 3/30  
20/20 [==============================] - 34s 2s/step - loss: 2.1989 - accuracy: 0.1883 - val_loss: 2.1288 - val_accuracy: 0.2241  
Epoch 4/30  
20/20 [==============================] - 34s 2s/step - loss: 2.1267 - accuracy: 0.2192 - val_loss: 2.0919 - val_accuracy: 0.2477  
Epoch 5/30  
20/20 [==============================] - 33s 2s/step - loss: 2.0653 - accuracy: 0.2393 - val_loss: 2.0134 - val_accuracy: 0.2671  
Epoch 6/30  
20/20 [==============================] - 34s 2s/step - loss: 2.0327 - accuracy: 0.2524 - val_loss: 2.0258 - val_accuracy: 0.2665  
Epoch 7/30  
20/20 [==============================] - 34s 2s/step - loss: 2.0047 - accuracy: 0.2598 - val_loss: 1.9871 - val_accuracy: 0.2831  
Epoch 8/30  
20/20 [==============================] - 34s 2s/step - loss: 1.9765 - accuracy: 0.2781 - val_loss: 1.9550 - val_accuracy: 0.2968  
Epoch 9/30  
20/20 [==============================] - 34s 2s/step - loss: 1.9432 - accuracy: 0.2883 - val_loss: 1.9559 - val_accuracy: 0.2969  
Epoch 10/30  
20/20 [==============================] - 33s 2s/step - loss: 1.9062 - accuracy: 0.3020 - val_loss: 1.8967 - val_accuracy: 0.3200  
Epoch 11/30  
20/20 [==============================] - 33s 2s/step - loss: 1.8741 - accuracy: 0.3158 - val_loss: 1.8648 - val_accuracy: 0.3330  
Epoch 12/30  
20/20 [==============================] - 33s 2s/step - loss: 1.8336 - accuracy: 0.3282 - val_loss: 1.7863 - val_accuracy: 0.3464  
Epoch 13/30  
20/20 [==============================] - 33s 2s/step - loss: 1.7931 - accuracy: 0.3434 - val_loss: 1.7364 - val_accuracy: 0.3669  
Epoch 14/30  
20/20 [==============================] - 34s 2s/step - loss: 1.7491 - accuracy: 0.3558 - val_loss: 1.7104 - val_accuracy: 0.3710  
Epoch 15/30  
20/20 [==============================] - 34s 2s/step - loss: 1.7182 - accuracy: 0.3686 - val_loss: 1.6883 - val_accuracy: 0.3866  
Epoch 16/30  
20/20 [==============================] - 33s 2s/step - loss: 1.6819 - accuracy: 0.3790 - val_loss: 1.6493 - val_accuracy: 0.3933  
Epoch 17/30  
20/20 [==============================] - 33s 2s/step - loss: 1.6594 - accuracy: 0.3873 - val_loss: 1.6021 - val_accuracy: 0.4161  
Epoch 18/30  
20/20 [==============================] - 33s 2s/step - loss: 1.6279 - accuracy: 0.3946 - val_loss: 1.5949 - val_accuracy: 0.4170  
Epoch 19/30  
20/20 [==============================] - 34s 2s/step - loss: 1.6127 - accuracy: 0.4015 - val_loss: 1.5672 - val_accuracy: 0.4239  
Epoch 20/30  
20/20 [==============================] - 33s 2s/step - loss: 1.5995 - accuracy: 0.4079 - val_loss: 1.5795 - val_accuracy: 0.4223  
Epoch 21/30  
20/20 [==============================] - 34s 2s/step - loss: 1.5809 - accuracy: 0.4167 - val_loss: 1.5294 - val_accuracy: 0.4390  
Epoch 22/30  
20/20 [==============================] - 34s 2s/step - loss: 1.5572 - accuracy: 0.4254 - val_loss: 1.5192 - val_accuracy: 0.4455  
Epoch 23/30  
20/20 [==============================] - 33s 2s/step - loss: 1.5468 - accuracy: 0.4291 - val_loss: 1.5243 - val_accuracy: 0.4424  
Epoch 24/30  
20/20 [==============================] - 34s 2s/step - loss: 1.5347 - accuracy: 0.4335 - val_loss: 1.4920 - val_accuracy: 0.4532  
Epoch 25/30  
20/20 [==============================] - 33s 2s/step - loss: 1.5245 - accuracy: 0.4387 - val_loss: 1.4805 - val_accuracy: 0.4584  
Epoch 26/30  
20/20 [==============================] - 33s 2s/step - loss: 1.5057 - accuracy: 0.4469 - val_loss: 1.4754 - val_accuracy: 0.4592  
Epoch 27/30  
20/20 [==============================] - 34s 2s/step - loss: 1.5013 - accuracy: 0.4457 - val_loss: 1.4688 - val_accuracy: 0.4619  
Epoch 28/30  
20/20 [==============================] - 33s 2s/step - loss: 1.4852 - accuracy: 0.4548 - val_loss: 1.4543 - val_accuracy: 0.4704  
Epoch 29/30  
20/20 [==============================] - 34s 2s/step - loss: 1.4728 - accuracy: 0.4570 - val_loss: 1.4437 - val_accuracy: 0.4751  
Epoch 30/30  
20/20 [==============================] - 34s 2s/step - loss: 1.4652 - accuracy: 0.4606 - val_loss: 1.4546 - val_accuracy: 0.4726  

可视化训练指标

model.fit() 将返回一个 history 对象,该对象存储了训练运行期间生成的指标值(但它是短暂的,需要手动保存)。

我们现在展示训练集和验证集的损失和准确率曲线。

plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()

png

png


可视化时间潜在瓶颈的注意力图

既然我们已经训练好了模型,是时候进行一些可视化了。快速流 (Transformers)处理一块令牌。慢速流处理每一块并 关注对任务有用的令牌。

在这一部分,我们可视化慢速流的注意力图。这是通过 从 TLB 层在每个块的交集中提取注意力分数并 将其存储在 RNN 的状态中实现的。接着对其进行“膨胀”并返回 这些值。

def score_to_viz(chunk_score):
    # 获取最受关注的令牌
    chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
    # 计算跨头的均值
    chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
    return chunk_viz


# 从测试数据集中获取一批图像和标签
images, labels = next(iter(test_ds))

# 将 get_attn_scores 标志设置为 True
model.rnn.cell.get_attention_scores = True

# 使用测试图像运行模型并抓取
# 注意力分数。
outputs = model(images)
list_chunk_scores = model.rnn.cell.attention_scores

# 处理注意力分数以进行可视化
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
    chunk_viz,
    (
        config["batch_size"],
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)
upsampled_heat_map = layers.UpSampling2D(
    size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)

运行以下代码片段以获取不同的图像及其注意力图。

# 随机选择一张图像
index = random.randint(0, config["batch_size"])
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

# 绘制可视化
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].imshow(orig_image)
ax[0].set_title("原始:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
ax[1].imshow(
    overlay_image,
    cmap="inferno",
    alpha=0.6,
    extent=image.get_extent(),
)
ax[1].set_title("TLB 注意力:")

plt.show()

png


结论

本示例展示了时间潜在瓶颈机制的实现。示例突出了使用历史状态的压缩和存储 以时间潜在瓶颈的形式以及来自感知模块的定期更新作为一种有效的方法。

在原始论文中,作者进行了广泛的测试,涵盖了从监督图像分类到在强化学习中应用的不同 模态。

虽然我们仅展示了一种将此机制应用于图像分类的方法,但它也可以通过最小的更改扩展到其他模态。

注意: 在构建此示例时,我们没有官方代码可供参考。这 意味着我们的实现是受论文启发的,并不声称是完全的复现。有关训练过程的更多详细信息,可以前往 我们的 GitHub 仓库


特别感谢

感谢 Aniket Didolkar(第一作者)和 Anirudh Goyal(第三作者) 审阅我们的工作。

我们还要感谢 PyImageSearch 提供 Colab Pro 帐户和 JarvisLabs.ai 提供 GPU 积分。