作者: Aritra Roy Gosthipaty, Suvaditya Mukherjee
创建日期: 2023/03/12
最后修改日期: 2023/03/12
描述: 使用时间潜在瓶颈网络进行图像分类。
一个简单的递归神经网络(RNN)表现出强烈的归纳偏差,倾向于学习
时间压缩表示。 公式 1 显示了递归公式,
其中 h_t
是整个输入序列 x
的压缩表示(一个单一向量)。
公式 1: 递归方程。 (来源: Aritra 和 Suvaditya) |
另一方面,变压器(Vaswani 等人)对学习时间压缩表示几乎没有归纳偏差。 变压器在自然语言处理(NLP)和视觉任务中通过其成对注意机制取得了最先进的成果。
虽然变压器有能力关注输入序列的不同部分,但注意力计算的性质是二次的。
Didolkar 等人认为,拥有更压缩的 序列表示可能对泛化是有益的,因为它可以更轻松地重用和重新利用,并且具有更少的不相关细节。虽然压缩是好的, 但他们也注意到过多的压缩会损害表达能力。
作者提出了一种解决方案,将计算分为两个流。一个慢流本质上是递归的,而一个快流被参数化为 变压器。虽然这种方法的新颖之处在于引入不同的处理流以保留和处理潜在状态,但它与其他 作品(如Perceiver Mechanism(Jaegle 等人) 和Grounded Language Learning Fast and Slow(Hill 等人))有相似之处。
下面的示例探索如何利用新的时间潜在瓶颈
机制在 CIFAR-10 数据集上进行图像分类。我们通过自定义 RNNCell
实现
该模型,以实现高性能和向量化设计。
注意: 这个示例使用 TensorFlow 2.12.0
,必须安装在我们的
系统中。
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW
import random
from matplotlib import pyplot as plt
# 设置种子以确保可重复性。
keras.utils.set_random_seed(42)
AUTO = tf.data.AUTOTUNE
我们设置一些在我们设计的管道中需要的配置参数。当前的参数用于 CIFAR10 数据集。
该模型还支持 mixed-precision
设置,这会将模型量化为在可用时使用
16位
浮点数,同时将某些参数保留为32位
以确保数值稳定性。这带来了性能上的好处,因为模型的占用空间显著减少,同时推理时速度提升。
config = {
"mixed_precision": True,
"dataset": "cifar10",
"train_slice": 40_000,
"batch_size": 2048,
"buffer_size": 2048 * 2,
"input_shape": [32, 32, 3],
"image_size": 48,
"num_classes": 10,
"learning_rate": 1e-4,
"weight_decay": 1e-4,
"epochs": 30,
"patch_size": 4,
"embed_dim": 64,
"chunk_size": 8,
"r": 2,
"num_layers": 4,
"ffn_drop": 0.2,
"attn_drop": 0.2,
"num_heads": 1,
}
if config["mixed_precision"]:
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_global_policy(policy)
INFO:tensorflow:混合精度兼容性检查(mixed_float16):OK
您的 GPU 在 dtype 策略 mixed_float16 下运行时可能会很快,因为它的计算能力至少为 7.0。您的 GPU: NVIDIA A100-PCIE-40GB,计算能力 8.0
我们将使用 CIFAR10 数据集进行实验。这个数据集
包含 50,000
张图像的训练集,分为 10
个类,标准图像大小为
(32, 32, 3)
。
它还有一组独立的 10,000
张具有相似特征的图像。有关数据集的更多信息,可以在数据集的官方网站以及 keras.datasets.cifar10
API 参考中找到。
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
(x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)
我们为对数据进行图像增强定义了独立的管道。这一步很重要,使模型对变化更具鲁棒性,帮助其更好地泛化。 我们执行的预处理和增强步骤如下:
重新缩放
(训练,测试):这一步骤旨在将所有图像像素值从 [0,255]
范围标准化到 [0,1)
。这有助于在训练过程中保持数值稳定性。调整大小
(训练,测试):我们将图像从其原始大小 (32, 32) 调整为 (52, 52)。这是为了适应随机裁剪,并符合论文中给出的数据规格。随机裁剪
(训练):该层随机选择图像的一个大小为 (48, 48)
的裁剪/子区域。随机翻转
(训练):该层随机水平翻转所有图像,保持图像大小不变。# 构建 `train` 增强管道。
train_augmentation = keras.Sequential(
[
layers.Rescaling(1 / 255.0, dtype="float32"),
layers.Resizing(
config["input_shape"][0] + 20,
config["input_shape"][0] + 20,
dtype="float32",
),
layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
layers.RandomFlip("horizontal", dtype="float32"),
],
name="train_data_augmentation",
)
# 构建 `val` 和 `test` 数据管道。
test_augmentation = keras.Sequential(
[
layers.Rescaling(1 / 255.0, dtype="float32"),
layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
],
name="test_data_augmentation",
)
# 我们定义函数而不是简单的 lambda 函数,以通过 [`keras.Sequential`](/api/models/sequential#sequential-class) 运行,以解决此警告:
# (https://github.com/tensorflow/tensorflow/issues/56089)
def train_map_fn(image, label):
return train_augmentation(image), label
def test_map_fn(image, label):
return test_augmentation(image), label
tf.data.Dataset
对象中np.ndarray
实例移入 tf.data.Dataset
实例.map()
应用增强.shuffle()
打乱数据集.batch()
进行批处理.prefetch()
启用批处理预取train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
train_ds.map(train_map_fn, num_parallel_calls=AUTO)
.shuffle(config["buffer_size"])
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
val_ds.map(test_map_fn, num_parallel_calls=AUTO)
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
test_ds.map(test_map_fn, num_parallel_calls=AUTO)
.batch(config["batch_size"], num_parallel_calls=AUTO)
.prefetch(AUTO)
)
摘自论文:
在大脑中,短期记忆和长期记忆以一种专业的方式发展。短期记忆能够快速变化,以反应即时的感官输入和感知。相反,长期记忆变化缓慢,具有高度选择性,并涉及重复的巩固。
受到短期和长期记忆的启发,作者介绍了快速流和慢速流计算。快速流具有高容量的短期记忆,可以快速响应感官输入(变换器)。慢速流具有长期记忆,以较慢速率更新并总结最相关的信息(递归)。
为了实现这个想法,我们需要:
图1:模型的架构。(来源:https://arxiv.org/abs/2205.14794) |
作者还提出了一个PyTorch风格的伪代码,如算法1所示。
算法1:PyTorch风格的伪代码。(来源:https://arxiv.org/abs/2205.14794) |
PatchEmbedding
层这个自定义的 keras.layers.Layer
用于从图像生成补丁,并使用 keras.layers.Embedding
将其转换为更高维的嵌入空间。补丁操作是通过 keras.layers.Conv2D
实例完成的,而不是传统的 tf.image.extract_patches
,以便实现向量化。
一旦图像补丁完成,我们将其重塑,以便获得展平的表示,其中维度的数量是嵌入维度。在这个阶段,我们还向令牌中注入位置信息。
在获得令牌后,我们将其分块。分块操作涉及从嵌入输出中获取固定大小的序列以创建“块”,这些块随后将作为模型的最终输入。
class PatchEmbedding(layers.Layer):
"""图像到补丁嵌入。
Args:
image_size (`Tuple[int]`): 输入图像的大小。
patch_size (`Tuple[int]`): 补丁的大小。
embed_dim (`int`): 嵌入的维度。
chunk_size (`int`): 要分块的补丁数量。
"""
def __init__(
self,
image_size,
patch_size,
embed_dim,
chunk_size,
**kwargs,
):
super().__init__(**kwargs)
# 计算补丁分辨率。
patch_resolution = [
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
]
# 存储参数。
self.image_size = image_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.patch_resolution = patch_resolution
self.num_patches = patch_resolution[0] * patch_resolution[1]
# 定义补丁的位置。
self.positions = tf.range(start=0, limit=self.num_patches, delta=1)
# 创建层。
self.projection = layers.Conv2D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
name="projection",
)
self.flatten = layers.Reshape(
target_shape=(-1, embed_dim),
name="flatten",
)
self.position_embedding = layers.Embedding(
input_dim=self.num_patches,
output_dim=embed_dim,
name="position_embedding",
)
self.layernorm = keras.layers.LayerNormalization(
epsilon=1e-5,
name="layernorm",
)
self.chunking_layer = layers.Reshape(
target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
name="chunking_layer",
)
def call(self, inputs):
# 将输入投影到嵌入维度。
x = self.projection(inputs)
# 展平补丁并添加位置嵌入。
x = self.flatten(x)
x = x + self.position_embedding(self.positions)
# 规范化嵌入。
x = self.layernorm(x)
# 分块令牌。
x = self.chunking_layer(x)
return x
FeedForwardNetwork
层这个自定义的 keras.layers.Layer
实例允许我们定义一个通用的FFN以及一个dropout。
class FeedForwardNetwork(layers.Layer):
"""前馈神经网络。
Args:
dims (`int`): FFN中的单元数量。
dropout (`float`): FFN的dropout概率。
"""
def __init__(self, dims, dropout, **kwargs):
super().__init__(**kwargs)
# 创建层。
self.ffn = keras.Sequential(
[
layers.Dense(units=4 * dims, activation=tf.nn.gelu),
layers.Dense(units=dims),
layers.Dropout(rate=dropout),
],
name="ffn",
)
self.layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="layernorm",
)
def call(self, inputs):
# 应用FFN。
x = self.layernorm(inputs)
x = inputs + self.ffn(x)
return x
BaseAttention
层这个自定义的 keras.layers.Layer
实例是一个 super
/base
类,它封装了一个 keras.layers.MultiHeadAttention
层以及一些其他组件。这为我们模型中的所有注意力层/模块提供了基本的共同功能。
class BaseAttention(layers.Layer):
"""基础注意力模块。
Args:
num_heads (`int`): 注意力头的数量。
key_dim (`int`): 每个注意力头的键的大小。
dropout (`float`): 注意力模块的丢弃概率。
"""
def __init__(self, num_heads, key_dim, dropout, **kwargs):
super().__init__(**kwargs)
self.multi_head_attention = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=dropout,
name="mha",
)
self.query_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="q_layernorm",
)
self.key_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="k_layernorm",
)
self.value_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="v_layernorm",
)
self.attention_scores = None
def call(self, input_query, key, value):
# 应用注意力模块。
query = self.query_layernorm(input_query)
key = self.key_layernorm(key)
value = self.value_layernorm(value)
(attention_outputs, attention_scores) = self.multi_head_attention(
query=query,
key=key,
value=value,
return_attention_scores=True,
)
# 保存注意力得分以供后续可视化。
self.attention_scores = attention_scores
# 将输入加到注意力输出中。
x = input_query + attention_outputs
return x
Attention
这个自定义的keras.layers.Layer
实现结合了BaseAttention
和FeedForwardNetwork
组件,开发出一个将在模型中重复使用的块。该模块具有高度的可定制性和灵活性,可以对内部层进行更改。
class AttentionWithFFN(layers.Layer):
"""带有前馈网络的注意力。
Args:
ffn_dims (`int`): 前馈网络中的单元数量。
ffn_dropout (`float`): 前馈网络的丢弃概率。
num_heads (`int`): 注意力头的数量。
key_dim (`int`): 每个注意力头的键的大小。
attn_dropout (`float`): 注意力模块的丢弃概率。
"""
def __init__(
self,
ffn_dims,
ffn_dropout,
num_heads,
key_dim,
attn_dropout,
**kwargs,
):
super().__init__(**kwargs)
# 创建层。
self.attention = BaseAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=attn_dropout,
name="base_attn",
)
self.ffn = FeedForwardNetwork(
dims=ffn_dims,
dropout=ffn_dropout,
name="ffn",
)
self.attention_scores = None
def call(self, query, key, value):
# 应用注意力模块。
x = self.attention(query, key, value)
# 保存注意力分数以便后续可视化。
self.attention_scores = self.attention.attention_scores
# 应用前馈网络。
x = self.ffn(x)
return x
算法1(伪代码)通过for循环描绘了递归。循环确实使得实现更简单,但会影响训练时间。在本节中,我们将自定义递归逻辑封装在CustomRecurrentCell
中。这个自定义单元将被封装在Keras RNN API中,使整个代码可以向量化。
这个作为keras.layers.Layer
实现的自定义单元是模型逻辑的重要组成部分。单元的功能可以分为两个部分:
- 慢流(时序潜在瓶颈):
AttentionWithFFN
层组成,该层解析前一个慢流的输出,一个中间隐藏表示(即时序潜在瓶颈中的潜在表示)作为查询,最新的快速流的输出作为键和值。该层也可以被视为交叉注意力层。AttentionWithFFN
层组成。该流由n层的SelfAttention
和CrossAttention
以顺序方式组成。class CustomRecurrentCell(layers.Layer):
"""自定义递归单元。
Args:
chunk_size (`int`): 每块中的标记数量。
r (`int`): 每 **r** 个自注意力一个交叉注意力。
num_layers (`int`): 层数。
ffn_dims (`int`): FFN 中的单元数。
ffn_dropout (`float`): FFN 的丢弃概率。
num_heads (`int`): 注意力头的数量。
key_dim (`int`): 每个注意力头的键的大小。
attn_dropout (`float`): 注意力模块的丢弃概率。
"""
def __init__(
self,
chunk_size,
r,
num_layers,
ffn_dims,
ffn_dropout,
num_heads,
key_dim,
attn_dropout,
**kwargs,
):
super().__init__(**kwargs)
# 保存参数。
self.chunk_size = chunk_size
self.r = r
self.num_layers = num_layers
self.ffn_dims = ffn_dims
self.ffn_droput = ffn_dropout
self.num_heads = num_heads
self.key_dim = key_dim
self.attn_dropout = attn_dropout
# 创建状态大小和输出大小。这对
# 自定义递归逻辑很重要。
self.state_size = tf.TensorShape([chunk_size, ffn_dims])
self.output_size = tf.TensorShape([chunk_size, ffn_dims])
self.get_attention_scores = False
self.attention_scores = []
# 知觉模块
perceptual_module = list()
for layer_idx in range(num_layers):
perceptual_module.append(
AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"pm_self_attn_{layer_idx}",
)
)
if layer_idx % r == 0:
perceptual_module.append(
AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"pm_cross_attn_ffn_{layer_idx}",
)
)
self.perceptual_module = perceptual_module
# 时间潜在瓶颈模块
self.tlb_module = AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"tlb_cross_attn_ffn",
)
def call(self, inputs, states):
# inputs => (batch, chunk_size, dims)
# states => [(batch, chunk_size, units)]
slow_stream = states[0]
fast_stream = inputs
for layer_idx, layer in enumerate(self.perceptual_module):
fast_stream = layer(query=fast_stream, key=fast_stream, value=fast_stream)
if layer_idx % self.r == 0:
fast_stream = layer(
query=fast_stream, key=slow_stream, value=slow_stream
)
slow_stream = self.tlb_module(
query=slow_stream, key=fast_stream, value=fast_stream
)
# 保存注意力得分以便后续可视化。
if self.get_attention_scores:
self.attention_scores.append(self.tlb_module.attention_scores)
return fast_stream, [slow_stream]
TemporalLatentBottleneckModel
封装完整模型在这里,我们只是将完整模型封装,以便用于训练。
class TemporalLatentBottleneckModel(keras.Model):
"""模型训练器。
参数:
patch_layer ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): 切片层。
custom_cell ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): 自定义递归单元。
"""
def __init__(self, patch_layer, custom_cell, **kwargs):
super().__init__(**kwargs)
self.patch_layer = patch_layer
self.rnn = layers.RNN(custom_cell, name="rnn")
self.gap = layers.GlobalAveragePooling1D(name="gap")
self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")
def call(self, inputs):
x = self.patch_layer(inputs)
x = self.rnn(x)
x = self.gap(x)
outputs = self.head(x)
return outputs
为了开始训练,我们现在单独定义组件并将它们作为参数传递给我们的包装类,以准备最终的训练模型。我们定义一个 PatchEmbed
层,以及基于 CustomCell
的 RNN。
# 构建模型。
patch_layer = PatchEmbedding(
image_size=(config["image_size"], config["image_size"]),
patch_size=(config["patch_size"], config["patch_size"]),
embed_dim=config["embed_dim"],
chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
chunk_size=config["chunk_size"],
r=config["r"],
num_layers=config["num_layers"],
ffn_dims=config["embed_dim"],
ffn_dropout=config["ffn_drop"],
num_heads=config["num_heads"],
key_dim=config["embed_dim"],
attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
patch_layer=patch_layer,
custom_cell=custom_rnn_cell,
)
我们使用 AdamW
优化器,因为它在多个基准任务中已被证明能非常好地执行。从优化的角度来看,它是 keras.optimizers.Adam
优化器的一个版本,同时具有权重衰减。
对于损失函数,我们使用 keras.losses.SparseCategoricalCrossentropy
函数,它利用预测和实际逻辑之间的简单交叉熵。我们还计算数据的准确性作为合理性检查。
optimizer = AdamW(
learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
model.fit()
训练模型我们传递训练数据集并运行训练。
history = model.fit(
train_ds,
epochs=config["epochs"],
validation_data=val_ds,
)
Epoch 1/30
20/20 [==============================] - 104s 3s/step - loss: 2.6284 - accuracy: 0.1010 - val_loss: 2.2835 - val_accuracy: 0.1251
Epoch 2/30
20/20 [==============================] - 35s 2s/step - loss: 2.2797 - accuracy: 0.1542 - val_loss: 2.1721 - val_accuracy: 0.1846
Epoch 3/30
20/20 [==============================] - 34s 2s/step - loss: 2.1989 - accuracy: 0.1883 - val_loss: 2.1288 - val_accuracy: 0.2241
Epoch 4/30
20/20 [==============================] - 34s 2s/step - loss: 2.1267 - accuracy: 0.2192 - val_loss: 2.0919 - val_accuracy: 0.2477
Epoch 5/30
20/20 [==============================] - 33s 2s/step - loss: 2.0653 - accuracy: 0.2393 - val_loss: 2.0134 - val_accuracy: 0.2671
Epoch 6/30
20/20 [==============================] - 34s 2s/step - loss: 2.0327 - accuracy: 0.2524 - val_loss: 2.0258 - val_accuracy: 0.2665
Epoch 7/30
20/20 [==============================] - 34s 2s/step - loss: 2.0047 - accuracy: 0.2598 - val_loss: 1.9871 - val_accuracy: 0.2831
Epoch 8/30
20/20 [==============================] - 34s 2s/step - loss: 1.9765 - accuracy: 0.2781 - val_loss: 1.9550 - val_accuracy: 0.2968
Epoch 9/30
20/20 [==============================] - 34s 2s/step - loss: 1.9432 - accuracy: 0.2883 - val_loss: 1.9559 - val_accuracy: 0.2969
Epoch 10/30
20/20 [==============================] - 33s 2s/step - loss: 1.9062 - accuracy: 0.3020 - val_loss: 1.8967 - val_accuracy: 0.3200
Epoch 11/30
20/20 [==============================] - 33s 2s/step - loss: 1.8741 - accuracy: 0.3158 - val_loss: 1.8648 - val_accuracy: 0.3330
Epoch 12/30
20/20 [==============================] - 33s 2s/step - loss: 1.8336 - accuracy: 0.3282 - val_loss: 1.7863 - val_accuracy: 0.3464
Epoch 13/30
20/20 [==============================] - 33s 2s/step - loss: 1.7931 - accuracy: 0.3434 - val_loss: 1.7364 - val_accuracy: 0.3669
Epoch 14/30
20/20 [==============================] - 34s 2s/step - loss: 1.7491 - accuracy: 0.3558 - val_loss: 1.7104 - val_accuracy: 0.3710
Epoch 15/30
20/20 [==============================] - 34s 2s/step - loss: 1.7182 - accuracy: 0.3686 - val_loss: 1.6883 - val_accuracy: 0.3866
Epoch 16/30
20/20 [==============================] - 33s 2s/step - loss: 1.6819 - accuracy: 0.3790 - val_loss: 1.6493 - val_accuracy: 0.3933
Epoch 17/30
20/20 [==============================] - 33s 2s/step - loss: 1.6594 - accuracy: 0.3873 - val_loss: 1.6021 - val_accuracy: 0.4161
Epoch 18/30
20/20 [==============================] - 33s 2s/step - loss: 1.6279 - accuracy: 0.3946 - val_loss: 1.5949 - val_accuracy: 0.4170
Epoch 19/30
20/20 [==============================] - 34s 2s/step - loss: 1.6127 - accuracy: 0.4015 - val_loss: 1.5672 - val_accuracy: 0.4239
Epoch 20/30
20/20 [==============================] - 33s 2s/step - loss: 1.5995 - accuracy: 0.4079 - val_loss: 1.5795 - val_accuracy: 0.4223
Epoch 21/30
20/20 [==============================] - 34s 2s/step - loss: 1.5809 - accuracy: 0.4167 - val_loss: 1.5294 - val_accuracy: 0.4390
Epoch 22/30
20/20 [==============================] - 34s 2s/step - loss: 1.5572 - accuracy: 0.4254 - val_loss: 1.5192 - val_accuracy: 0.4455
Epoch 23/30
20/20 [==============================] - 33s 2s/step - loss: 1.5468 - accuracy: 0.4291 - val_loss: 1.5243 - val_accuracy: 0.4424
Epoch 24/30
20/20 [==============================] - 34s 2s/step - loss: 1.5347 - accuracy: 0.4335 - val_loss: 1.4920 - val_accuracy: 0.4532
Epoch 25/30
20/20 [==============================] - 33s 2s/step - loss: 1.5245 - accuracy: 0.4387 - val_loss: 1.4805 - val_accuracy: 0.4584
Epoch 26/30
20/20 [==============================] - 33s 2s/step - loss: 1.5057 - accuracy: 0.4469 - val_loss: 1.4754 - val_accuracy: 0.4592
Epoch 27/30
20/20 [==============================] - 34s 2s/step - loss: 1.5013 - accuracy: 0.4457 - val_loss: 1.4688 - val_accuracy: 0.4619
Epoch 28/30
20/20 [==============================] - 33s 2s/step - loss: 1.4852 - accuracy: 0.4548 - val_loss: 1.4543 - val_accuracy: 0.4704
Epoch 29/30
20/20 [==============================] - 34s 2s/step - loss: 1.4728 - accuracy: 0.4570 - val_loss: 1.4437 - val_accuracy: 0.4751
Epoch 30/30
20/20 [==============================] - 34s 2s/step - loss: 1.4652 - accuracy: 0.4606 - val_loss: 1.4546 - val_accuracy: 0.4726
model.fit()
将返回一个 history
对象,该对象存储了训练运行期间生成的指标值(但它是短暂的,需要手动保存)。
我们现在展示训练集和验证集的损失和准确率曲线。
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
既然我们已经训练好了模型,是时候进行一些可视化了。快速流 (Transformers)处理一块令牌。慢速流处理每一块并 关注对任务有用的令牌。
在这一部分,我们可视化慢速流的注意力图。这是通过 从 TLB 层在每个块的交集中提取注意力分数并 将其存储在 RNN 的状态中实现的。接着对其进行“膨胀”并返回 这些值。
def score_to_viz(chunk_score):
# 获取最受关注的令牌
chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
# 计算跨头的均值
chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
return chunk_viz
# 从测试数据集中获取一批图像和标签
images, labels = next(iter(test_ds))
# 将 get_attn_scores 标志设置为 True
model.rnn.cell.get_attention_scores = True
# 使用测试图像运行模型并抓取
# 注意力分数。
outputs = model(images)
list_chunk_scores = model.rnn.cell.attention_scores
# 处理注意力分数以进行可视化
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
chunk_viz,
(
config["batch_size"],
config["image_size"] // config["patch_size"],
config["image_size"] // config["patch_size"],
1,
),
)
upsampled_heat_map = layers.UpSampling2D(
size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)
运行以下代码片段以获取不同的图像及其注意力图。
# 随机选择一张图像
index = random.randint(0, config["batch_size"])
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]
# 绘制可视化
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax[0].imshow(orig_image)
ax[0].set_title("原始:")
ax[0].axis("off")
image = ax[1].imshow(orig_image)
ax[1].imshow(
overlay_image,
cmap="inferno",
alpha=0.6,
extent=image.get_extent(),
)
ax[1].set_title("TLB 注意力:")
plt.show()
本示例展示了时间潜在瓶颈机制的实现。示例突出了使用历史状态的压缩和存储 以时间潜在瓶颈的形式以及来自感知模块的定期更新作为一种有效的方法。
在原始论文中,作者进行了广泛的测试,涵盖了从监督图像分类到在强化学习中应用的不同 模态。
虽然我们仅展示了一种将此机制应用于图像分类的方法,但它也可以通过最小的更改扩展到其他模态。
注意: 在构建此示例时,我们没有官方代码可供参考。这 意味着我们的实现是受论文启发的,并不声称是完全的复现。有关训练过程的更多详细信息,可以前往 我们的 GitHub 仓库。
感谢 Aniket Didolkar(第一作者)和 Anirudh Goyal(第三作者) 审阅我们的工作。
我们还要感谢 PyImageSearch 提供 Colab Pro 帐户和 JarvisLabs.ai 提供 GPU 积分。