作者: Aritra Roy Gosthipaty, Ayush Thakur(共同贡献)
创建日期: 2022/01/12
最后修改日期: 2024/01/15
描述: 一种基于Transformer的视频分类架构。
视频是图像的序列。假设您手头有一个图像表示模型(CNN、ViT等)和一个序列模型(RNN、LSTM等)。我们要求您调整模型以进行视频分类。最简单的方法就是对单个帧应用图像模型,使用序列模型学习图像特征的序列,然后在学习到的序列表示上应用分类头。Keras示例使用CNN-RNN架构的视频分类详细解释了这种方法。或者,您还可以像Keras示例基于Transformer的视频分类中所示的那样,构建一个混合的基于Transformer的视频分类模型。
在本示例中,我们最小化实现了ViViT: A Video Vision Transformer由Arnab等人提出,这是一个纯基于Transformer的视频分类模型。作者提出了一种新颖的嵌入方案和多种Transformer变体来建模视频剪辑。为了简化,我们实现了嵌入方案和Transformer架构的一个变体。
此示例需要medmnist
包,可以通过运行下面的代码单元安装。
!pip install -qq medmnist
import os
import io
import imageio
import medmnist
import ipywidgets
import numpy as np
import tensorflow as tf # 仅用于数据预处理
import keras
from keras import layers, ops
# 设置随机种子以确保可重复性
SEED = 42
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
keras.utils.set_random_seed(SEED)
超参数通过超参数搜索选择。您可以在“结论”部分了解更多关于该过程的信息。
# 数据
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (28, 28, 28, 1)
NUM_CLASSES = 11
# 优化器
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
# 训练
EPOCHS = 60
# 管道嵌入
PATCH_SIZE = (8, 8, 8)
NUM_PATCHES = (INPUT_SHAPE[0] // PATCH_SIZE[0]) ** 2
# ViViT架构
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 8
NUM_LAYERS = 8
在我们的示例中,我们使用MedMNIST v2:用于2D和3D生物医学图像分类的大规模轻量级基准数据集。视频轻量且易于训练。
def download_and_prepare_dataset(data_info: dict):
"""下载数据集的工具函数。
参数:
data_info (dict): 数据集元数据。
"""
data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])
with np.load(data_path) as data:
# 获取视频
train_videos = data["train_images"]
valid_videos = data["val_images"]
test_videos = data["test_images"]
# 获取标签
train_labels = data["train_labels"].flatten()
valid_labels = data["val_labels"].flatten()
test_labels = data["test_labels"].flatten()
return (
(train_videos, train_labels),
(valid_videos, valid_labels),
(test_videos, test_labels),
)
# 获取数据集的元数据
info = medmnist.INFO[DATASET_NAME]
# 获取数据集
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]
tf.data
管道def preprocess(frames: tf.Tensor, label: tf.Tensor):
"""预处理帧张量并解析标签。"""
# 预处理图像
frames = tf.image.convert_image_dtype(
frames[
..., tf.newaxis
], # 新轴是为了便于进一步与Conv3D层处理
tf.float32,
)
# 解析标签
label = tf.cast(label, tf.float32)
return frames, label
def prepare_dataloader(
videos: np.ndarray,
labels: np.ndarray,
loader_type: str = "train",
batch_size: int = BATCH_SIZE,
):
"""准备数据加载器的实用函数。"""
dataset = tf.data.Dataset.from_tensor_slices((videos, labels))
if loader_type == "train":
dataset = dataset.shuffle(BATCH_SIZE * 2)
dataloader = (
dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
.batch(batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return dataloader
trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")
在ViTs中,图像被划分为小块,然后被空间展平,这一过程被称为标记化。对于视频,可以对单独的帧重复此过程。均匀帧采样是作者建议的一种标记化方案,其中我们从视频片段中采样帧,并进行简单的ViT标记化。
均匀帧采样 来源 |
Tubelet嵌入在捕获视频的时间信息方面有所不同。 首先,我们从视频中提取体积——这些体积包含了帧的小块和时间信息。然后将体积展平以构建视频标记。
Tubelet嵌入 来源 |
class TubeletEmbedding(layers.Layer):
def __init__(self, embed_dim, patch_size, **kwargs):
super().__init__(**kwargs)
self.projection = layers.Conv3D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
padding="VALID",
)
self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
def call(self, videos):
projected_patches = self.projection(videos)
flattened_patches = self.flatten(projected_patches)
return flattened_patches
这一层向编码的视频标记添加位置信息。
class PositionalEncoder(layers.Layer):
def __init__(self, embed_dim, **kwargs):
super().__init__(**kwargs)
self.embed_dim = embed_dim
def build(self, input_shape):
_, num_tokens, _ = input_shape
self.position_embedding = layers.Embedding(
input_dim=num_tokens, output_dim=self.embed_dim
)
self.positions = ops.arange(0, num_tokens, 1)
def call(self, encoded_tokens):
# 编码位置并将其添加到编码的标记
encoded_positions = self.position_embedding(self.positions)
encoded_tokens = encoded_tokens + encoded_positions
return encoded_tokens
作者提出了四种变体的视觉变换器:
在这个例子中,我们将实现时空注意力模型以简化。以下代码片段深受使用视觉变换器的图像分类的启发。还可以参考ViViT的官方仓库,其中包含所有变体,使用JAX实现。
def create_vivit_classifier(
tubelet_embedder,
positional_encoder,
input_shape=INPUT_SHAPE,
transformer_layers=NUM_LAYERS,
num_heads=NUM_HEADS,
embed_dim=PROJECTION_DIM,
layer_norm_eps=LAYER_NORM_EPS,
num_classes=NUM_CLASSES,
):
# 获取输入层
inputs = layers.Input(shape=input_shape)
# 创建小块。
patches = tubelet_embedder(inputs)
# 编码小块。
encoded_patches = positional_encoder(patches)
# 创建多个Transformer块的层。
for _ in range(transformer_layers):
# 层归一化和MHSA
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=0.1
)(x1, x1)
# 跳跃连接
x2 = layers.Add()([attention_output, encoded_patches])
# 层归一化和MLP
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
x3 = keras.Sequential(
[
layers.Dense(units=embed_dim * 4, activation=ops.gelu),
layers.Dense(units=embed_dim, activation=ops.gelu),
]
)(x3)
# 跳跃连接
encoded_patches = layers.Add()([x3, x2])
# 层归一化和全局平均池化。
representation = layers.LayerNormalization(epsilon=layer_norm_eps)(encoded_patches)
representation = layers.GlobalAvgPool1D()(representation)
# 分类输出。
outputs = layers.Dense(units=num_classes, activation="softmax")(representation)
# 创建Keras模型。
model = keras.Model(inputs=inputs, outputs=outputs)
return model
def run_experiment():
# 初始化模型
model = create_vivit_classifier(
tubelet_embedder=TubeletEmbedding(
embed_dim=PROJECTION_DIM, patch_size=PATCH_SIZE
),
positional_encoder=PositionalEncoder(embed_dim=PROJECTION_DIM),
)
# 使用优化器、损失函数和指标编译模型。
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# 训练模型。
_ = model.fit(trainloader, epochs=EPOCHS, validation_data=validloader)
_, accuracy, top_5_accuracy = model.evaluate(testloader)
print(f"测试准确率: {round(accuracy * 100, 2)}%")
print(f"测试前 5 准确率: {round(top_5_accuracy * 100, 2)}%")
return model
model = run_experiment()
测试准确率: 76.72%
测试前 5 准确率: 97.54%
NUM_SAMPLES_VIZ = 25
testsamples, labels = next(iter(testloader))
testsamples, labels = testsamples[:NUM_SAMPLES_VIZ], labels[:NUM_SAMPLES_VIZ]
ground_truths = []
preds = []
videos = []
for i, (testsample, label) in enumerate(zip(testsamples, labels)):
# 生成 gif
testsample = np.reshape(testsample.numpy(), (-1, 28, 28))
with io.BytesIO() as gif:
imageio.mimsave(gif, (testsample * 255).astype("uint8"), "GIF", fps=5)
videos.append(gif.getvalue())
# 获取模型预测
output = model.predict(ops.expand_dims(testsample, axis=0))[0]
pred = np.argmax(output, axis=0)
ground_truths.append(label.numpy().astype("int"))
preds.append(pred)
def make_box_for_grid(image_widget, fit):
"""创建一个 VBox 来容纳标题/图像以演示 option_fit 值。
来源: https://ipywidgets.readthedocs.io/en/latest/examples/Widget%20Styling.html
"""
# 制作标题
if fit is not None:
fit_str = "'{}'".format(fit)
else:
fit_str = str(fit)
h = ipywidgets.HTML(value="" + str(fit_str) + "")
# 制作带有图像小部件的绿色框
boxb = ipywidgets.widgets.Box()
boxb.children = [image_widget]
# 组成一个垂直框
vb = ipywidgets.widgets.VBox()
vb.layout.align_items = "center"
vb.children = [h, boxb]
return vb
boxes = []
for i in range(NUM_SAMPLES_VIZ):
ib = ipywidgets.widgets.Image(value=videos[i], width=100, height=100)
true_class = info["label"][str(ground_truths[i])]
pred_class = info["label"][str(preds[i])]
caption = f"T: {true_class} | P: {pred_class}"
boxes.append(make_box_for_grid(ib, caption))
ipywidgets.widgets.GridBox(
boxes, layout=ipywidgets.widgets.Layout(grid_template_columns="repeat(5, 200px)")
)
通过基础实现,我们在测试数据集上达到了 ~79-80% 的前 1 准确率。
本教程中使用的超参数是通过运行超参数搜索确定的,使用了 W&B Sweeps。 您可以在 这里找到我们的搜索结果,以及 这里有关结果的快速分析。
为了进一步改进,您可以考虑以下几点:
我们要感谢 Anurag Arnab (ViViT 的第一作者)提供的有益讨论。我们对 Weights and Biases 项目表示感谢,它帮助我们解决了 GPU 额度的问题。
您可以使用托管在 Hugging Face Hub 上的训练模型,并在 Hugging Face Spaces 上尝试演示。