Author: Aritra Roy Gosthipaty, Ritwik Raha, Shivalika Singh
Date created: 2022/02/24
Last modified: 2022/10/15
Description: A minimal implementation of ShiftViT.
Vision Transformers (ViTs) have sparked a wave of research at the intersection of Transformers and Computer Vision (CV).
ViTs can simultaneously model long-range and short-range dependencies, thanks to the Multi-Head Self-Attention mechanism in the Transformer block. Many researchers believe that the success of ViTs is purely due to the attention layer, and they seldom consider other parts of the ViT model.
In the academic paper When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism, the authors propose to demystify the success of ViTs by replacing the attention operation with a parameter-free operation: they swap attention for a shift operation.
In this example, we minimally implement the paper, staying closely aligned with the authors' official implementation.
This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can be installed with the following command:
!pip install -qq -U tensorflow-addons
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import pathlib
import glob
# Setting seed for reproducibility
SEED = 42
keras.utils.set_random_seed(SEED)
These are the hyperparameters that we have chosen for the experiment. Please feel free to tune them.
class Config(object):
# DATA
batch_size = 256
buffer_size = batch_size * 2
input_shape = (32, 32, 3)
num_classes = 10
# AUGMENTATION
image_size = 48
# ARCHITECTURE
patch_size = 4
projected_dim = 96
num_shift_blocks_per_stages = [2, 4, 8, 2]
epsilon = 1e-5
stochastic_depth_rate = 0.2
mlp_dropout_rate = 0.2
num_div = 12
shift_pixel = 1
mlp_expand_ratio = 2
# OPTIMIZER
lr_start = 1e-5
lr_max = 1e-3
weight_decay = 1e-4
# TRAINING
epochs = 100
# INFERENCE
label_map = {
0: "airplane",
1: "automobile",
2: "bird",
3: "cat",
4: "deer",
5: "dog",
6: "frog",
7: "horse",
8: "ship",
9: "truck",
}
tf_ds_batch_size = 20
config = Config()
We use the CIFAR-10 dataset for our experiments.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:40000], y_train[:40000]),
(x_train[40000:], y_train[40000:]),
)
print(f"训练样本: {len(x_train)}")
print(f"验证样本: {len(x_val)}")
print(f"测试样本: {len(x_test)}")
AUTO = tf.data.AUTOTUNE
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(config.batch_size).prefetch(AUTO)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 3s 0us/step
Training samples: 40000
Validation samples: 10000
Testing samples: 10000
The augmentation pipeline consists of:
- Resizing
- Random cropping
- Random horizontal flipping
- Rescaling
Note: The image augmentation layers do not apply data transformations at inference time. This means that these layers behave differently when called with training=False. Refer to the documentation for more details.
def get_augmentation_model():
"""构建数据增强模型。"""
data_augmentation = keras.Sequential(
[
layers.Resizing(config.input_shape[0] + 20, config.input_shape[0] + 20),
layers.RandomCrop(config.image_size, config.image_size),
layers.RandomFlip("horizontal"),
layers.Rescaling(1 / 255.0),
]
)
return data_augmentation
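As a quick, purely illustrative check (not part of the training pipeline), the augmentation model can be applied to a few raw CIFAR-10 images. The random crop and flip are only active when the model is called with training=True, while resizing and rescaling are applied in both modes:
# Illustrative check of the augmentation model (assumed sample batch of 4 images):
augmenter = get_augmentation_model()
sample_batch = tf.cast(x_train[:4], dtype=tf.float32)
train_out = augmenter(sample_batch, training=True)   # random crop + random flip applied
infer_out = augmenter(sample_batch, training=False)  # deterministic behaviour
print(train_out.shape, infer_out.shape)  # both (4, 48, 48, 3); pixel values rescaled to [0, 1]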
In this section, we build the architecture proposed in the ShiftViT paper.
Figure 1: The entire architecture of ShiftViT. (Source)
As shown in Figure 1, the architecture is inspired by Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. Here the authors propose a modular architecture with 4 stages. Each stage works on its own spatial size, creating a hierarchical architecture.
An input image of size HxWx3 is split into non-overlapping patches of size 4x4. This is done via the patchify layer, which results in individual tokens of feature size 48 (4x4x3); see the small sketch below. Each stage comprises two essential building blocks:
- Embedding generation
- Stacked shift blocks
We discuss the stages and the modules in detail in what follows.
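To make the tokenization concrete, here is a minimal, illustrative sketch (the shapes are assumptions for a 48x48 augmented CIFAR-10 image, not part of the model code) showing how a strided Conv2D with kernel_size == strides == patch_size acts as the patchify-and-embed step:
# Minimal patchify sketch (illustrative shapes only):
dummy_image = tf.random.normal((1, 48, 48, 3))
patchify = layers.Conv2D(filters=96, kernel_size=4, strides=4, padding="same")
dummy_tokens = patchify(dummy_image)
print(dummy_tokens.shape)  # (1, 12, 12, 96): a 12x12 grid of tokens, each 4x4x3 = 48
                           # raw pixel values projected to 96 channels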
Note: With respect to the official implementation, we restructure some key components to better fit the Keras API.
Figure 2: From the model to the shift block.
Each stage in the ShiftViT architecture comprises a shift block, as shown in Figure 2.
Figure 3: The Shift ViT block. (Source)
As shown in Figure 3, the shift block comprises the following:
- a shift operation
- a layer normalization layer
- an MLP layer
The MLP block is intended to be a stack of densely-connected layers.
class MLP(layers.Layer):
"""获取每个移位块的MLP层。
参数:
mlp_expand_ratio (int):第一个特征图扩展的比率。
mlp_dropout_rate (float):丢弃的比率。
"""
def __init__(self, mlp_expand_ratio, mlp_dropout_rate, **kwargs):
super().__init__(**kwargs)
self.mlp_expand_ratio = mlp_expand_ratio
self.mlp_dropout_rate = mlp_dropout_rate
def build(self, input_shape):
input_channels = input_shape[-1]
initial_filters = int(self.mlp_expand_ratio * input_channels)
self.mlp = keras.Sequential(
[
layers.Dense(
units=initial_filters,
activation=tf.nn.gelu,
),
layers.Dropout(rate=self.mlp_dropout_rate),
layers.Dense(units=input_channels),
layers.Dropout(rate=self.mlp_dropout_rate),
]
)
def call(self, x):
x = self.mlp(x)
return x
Stochastic depth is a regularization technique that randomly drops a set of layers. During inference, the layers are kept as they are. It is very similar to Dropout, but it operates on a block of layers rather than on individual nodes.
class DropPath(layers.Layer):
"""Drop Path,也称为随机深度层。
参考:
- https://keras.io/examples/vision/cct/#stochastic-depth-for-regularization
- github.com:rwightman/pytorch-image-models
"""
def __init__(self, drop_path_prob, **kwargs):
super().__init__(**kwargs)
self.drop_path_prob = drop_path_prob
def call(self, x, training=False):
if training:
keep_prob = 1 - self.drop_path_prob
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
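A quick illustrative check of the behaviour (using a dummy all-ones batch, not part of the model): with training=True, whole samples are randomly zeroed out and the survivors are scaled by 1 / keep_prob, while with training=False the layer is an identity.
# Illustrative DropPath check (dummy batch of ones):
drop_path_demo = DropPath(drop_path_prob=0.5)
dummy = tf.ones((4, 2, 2, 3))
out_train = drop_path_demo(dummy, training=True)   # some samples dropped, survivors scaled by 2
out_infer = drop_path_demo(dummy, training=False)  # identical to the input
print(tf.reduce_sum(out_train, axis=[1, 2, 3]).numpy())  # e.g. [24.  0. 24. 24.]
print(tf.reduce_sum(out_infer, axis=[1, 2, 3]).numpy())  # [12. 12. 12. 12.]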
The most important operation in this paper is the shift operation. In this section, we describe the shift operation and compare it with the original implementation provided by the authors.
A generic feature map is assumed to have the shape [N, H, W, C]. Here we choose a num_div parameter that decides the division size of the channels. The first 4 divisions are shifted (by 1 pixel) in the left, right, up, and down directions. The remaining splits are kept as is. After the partial shifting, the shifted channels are padded and the overflown pixels are chopped off. This completes the partial shifting operation.
In the original implementation, the code looks roughly like this:
out[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
out[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
out[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
out[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down
out[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
In TensorFlow, it would be infeasible for us to assign shifted channels to a tensor in the middle of the training process. Therefore, we adopt the following procedure:
- Split the channels with the num_div parameter.
- Shift and pad each of the first four splits in its respective direction.
- After shifting and padding, concatenate the channels back.
Figure 4: The TensorFlow-style shift.
The entire procedure is illustrated in Figure 4.
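Before wrapping the shift into a layer, here is a small, self-contained illustration of the crop-then-pad idea on a toy 1x4x4x1 tensor (arbitrary values, not the model's exact helper): dropping the first column and zero-padding the right edge back to the original width realizes a one-pixel left shift, with the overflowing column chopped off.
# Toy illustration of the crop-and-pad shift (not part of the model):
toy = tf.reshape(tf.range(16, dtype=tf.float32), (1, 4, 4, 1))
crop = tf.image.crop_to_bounding_box(toy, 0, 1, 4, 3)  # keep columns 1..3
shifted_left = tf.image.pad_to_bounding_box(crop, 0, 0, 4, 4)  # zero column appears on the right
print(tf.squeeze(toy).numpy()[0])           # [0. 1. 2. 3.]
print(tf.squeeze(shifted_left).numpy()[0])  # [1. 2. 3. 0.]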
class ShiftViTBlock(layers.Layer):
"""一个单位 ShiftViT 块
参数:
shift_pixel (int): 移动的像素数。默认为1。
mlp_expand_ratio (int): MLP特征扩展的比例。
默认为2。
mlp_dropout_rate (float): MLP中使用的丢弃率。
num_div (int): 特征图通道的分割数。
总共将移动 4/num_div 的通道。默认为12。
epsilon (float): Epsilon 常量。
drop_path_prob (float): 丢弃路径的丢弃概率。
"""
def __init__(
self,
epsilon,
drop_path_prob,
mlp_dropout_rate,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.shift_pixel = shift_pixel
self.mlp_expand_ratio = mlp_expand_ratio
self.mlp_dropout_rate = mlp_dropout_rate
self.num_div = num_div
self.epsilon = epsilon
self.drop_path_prob = drop_path_prob
def build(self, input_shape):
self.H = input_shape[1]
self.W = input_shape[2]
self.C = input_shape[3]
self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
self.drop_path = (
DropPath(drop_path_prob=self.drop_path_prob)
if self.drop_path_prob > 0.0
else layers.Activation("linear")
)
self.mlp = MLP(
mlp_expand_ratio=self.mlp_expand_ratio,
mlp_dropout_rate=self.mlp_dropout_rate,
)
def get_shift_pad(self, x, mode):
"""根据所选的模式移动通道。"""
if mode == "left":
offset_height = 0
offset_width = 0
target_height = 0
target_width = self.shift_pixel
elif mode == "right":
offset_height = 0
offset_width = self.shift_pixel
target_height = 0
target_width = self.shift_pixel
elif mode == "up":
offset_height = 0
offset_width = 0
target_height = self.shift_pixel
target_width = 0
else:
offset_height = self.shift_pixel
offset_width = 0
target_height = self.shift_pixel
target_width = 0
crop = tf.image.crop_to_bounding_box(
x,
offset_height=offset_height,
offset_width=offset_width,
target_height=self.H - target_height,
target_width=self.W - target_width,
)
shift_pad = tf.image.pad_to_bounding_box(
crop,
offset_height=offset_height,
offset_width=offset_width,
target_height=self.H,
target_width=self.W,
)
return shift_pad
def call(self, x, training=False):
# Split the feature maps
x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)
# Shift the feature maps
x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
x_splits[1] = self.get_shift_pad(x_splits[1], mode="right")
x_splits[2] = self.get_shift_pad(x_splits[2], mode="up")
x_splits[3] = self.get_shift_pad(x_splits[3], mode="down")
# Concatenate the shifted and unshifted feature maps
x = tf.concat(x_splits, axis=-1)
# Add the residual connection
shortcut = x
x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)
return x
Figure 5: Shift blocks in the architecture. (Source)
Each stage of the architecture has shift blocks as shown in Figure 5. Each of these blocks contains a variable number of stacked ShiftViT blocks (as built in the earlier section).
The shift blocks are followed by a PatchMerging layer that scales down the feature inputs. The PatchMerging layer contributes to the pyramidal structure of the model.
This layer merges two adjacent tokens. It helps scale the features down spatially and scale them up channel-wise. We use a Conv2D layer to merge the patches.
class PatchMerging(layers.Layer):
"""补丁合并层。
Args:
epsilon (float): epsilon常数。
"""
def __init__(self, epsilon, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
def build(self, input_shape):
filters = 2 * input_shape[-1]
self.reduction = layers.Conv2D(
filters=filters, kernel_size=2, strides=2, padding="same", use_bias=False
)
self.layer_norm = layers.LayerNormalization(epsilon=self.epsilon)
def call(self, x):
# Apply the patch merging algorithm on the feature maps
x = self.layer_norm(x)
x = self.reduction(x)
return x
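As an illustrative shape check (dummy input, not part of the model), the merging layer halves each spatial dimension while doubling the channel count:
# Illustrative shape check for PatchMerging (dummy token grid):
patch_merge_demo = PatchMerging(epsilon=config.epsilon)
merged = patch_merge_demo(tf.random.normal((1, 12, 12, 96)))
print(merged.shape)  # (1, 6, 6, 192): spatial size halved, channels doubled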
Each stage has a variable number of stacked ShiftViT blocks, as suggested in the paper. This is a generic layer that contains the stacked shift vit blocks together with the patch merging layer. Combining the two operations (shift ViT block and patch merging) is a design choice we picked for better code reusability.
# Note: This layer will have a different depth of stacking for different stages of the model.
class StackedShiftBlocks(layers.Layer):
"""包含堆叠的 ShiftViTBlocks 的层。
参数:
epsilon (float): epsilon 常数。
mlp_dropout_rate (float): MLP 块中使用的 dropout 率。
num_shift_blocks (int): 此阶段的移位 vit 块数量。
stochastic_depth_rate (float): 选择的最大 dropout 路径率。
is_merge (boolean): 一个标志,用于确定是否在移位 vit 块之后使用 Patch Merge 层。
num_div (int): 特征图的通道数划分。默认为 12。
shift_pixel (int): 移位的像素数。默认为 1。
mlp_expand_ratio (int): MLP 的初始密集层扩展的比例。默认为 2。
"""
def __init__(
self,
epsilon,
mlp_dropout_rate,
num_shift_blocks,
stochastic_depth_rate,
is_merge,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.epsilon = epsilon
self.mlp_dropout_rate = mlp_dropout_rate
self.num_shift_blocks = num_shift_blocks
self.stochastic_depth_rate = stochastic_depth_rate
self.is_merge = is_merge
self.num_div = num_div
self.shift_pixel = shift_pixel
self.mlp_expand_ratio = mlp_expand_ratio
def build(self, input_shapes):
# Calculate stochastic depth probabilities.
# Reference: https://keras.io/examples/vision/cct/#the-final-cct-model
dpr = [
x
for x in np.linspace(
start=0, stop=self.stochastic_depth_rate, num=self.num_shift_blocks
)
]
# Build the shift blocks as a list of ShiftViT Blocks
self.shift_blocks = list()
for num in range(self.num_shift_blocks):
self.shift_blocks.append(
ShiftViTBlock(
num_div=self.num_div,
epsilon=self.epsilon,
drop_path_prob=dpr[num],
mlp_dropout_rate=self.mlp_dropout_rate,
shift_pixel=self.shift_pixel,
mlp_expand_ratio=self.mlp_expand_ratio,
)
)
if self.is_merge:
self.patch_merge = PatchMerging(epsilon=self.epsilon)
def call(self, x, training=False):
for shift_block in self.shift_blocks:
x = shift_block(x, training=training)
if self.is_merge:
x = self.patch_merge(x)
return x
# Since this is a custom layer, we need to overwrite get_config()
# so that the model can be easily saved and loaded after training
def get_config(self):
config = super().get_config()
config.update(
{
"epsilon": self.epsilon,
"mlp_dropout_rate": self.mlp_dropout_rate,
"num_shift_blocks": self.num_shift_blocks,
"stochastic_depth_rate": self.stochastic_depth_rate,
"is_merge": self.is_merge,
"num_div": self.num_div,
"shift_pixel": self.shift_pixel,
"mlp_expand_ratio": self.mlp_expand_ratio,
}
)
return config
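A quick illustrative check of one stage (dummy input, arbitrary arguments): a stage with is_merge=True preserves the token grid through its shift blocks and then halves it via patch merging.
# Illustrative stage check (dummy input, arbitrary arguments):
stage_demo = StackedShiftBlocks(
    epsilon=config.epsilon,
    mlp_dropout_rate=config.mlp_dropout_rate,
    num_shift_blocks=2,
    stochastic_depth_rate=config.stochastic_depth_rate,
    is_merge=True,
)
print(stage_demo(tf.random.normal((1, 12, 12, 96))).shape)  # (1, 6, 6, 192)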
Build the ShiftViT custom model.
class ShiftViTModel(keras.Model):
"""ShiftViT模型。
参数:
data_augmentation (keras.Model): 数据增强模型。
projected_dim (int): 图像块投影的维度。
patch_size (int): 图像的块大小。
num_shift_blocks_per_stages (list[int]): 每个阶段的所有移位块的数量列表。
epsilon (float): epsilon常数。
mlp_dropout_rate (float): 在MLP块中使用的丢弃率。
stochastic_depth_rate (float): 最大丢弃率概率。
num_div (int): 特征图通道的划分数量。默认为12。
shift_pixel (int): 移位的像素数量。默认为1。
mlp_expand_ratio (int): 初始MLP密集层扩展到的比例。默认为2。
"""
def __init__(
self,
data_augmentation,
projected_dim,
patch_size,
num_shift_blocks_per_stages,
epsilon,
mlp_dropout_rate,
stochastic_depth_rate,
num_div=12,
shift_pixel=1,
mlp_expand_ratio=2,
**kwargs,
):
super().__init__(**kwargs)
self.data_augmentation = data_augmentation
self.patch_projection = layers.Conv2D(
filters=projected_dim,
kernel_size=patch_size,
strides=patch_size,
padding="same",
)
self.stages = list()
for index, num_shift_blocks in enumerate(num_shift_blocks_per_stages):
if index == len(num_shift_blocks_per_stages) - 1:
# This is the last stage, do not use the patch merge here.
is_merge = False
else:
is_merge = True
# Build the stages.
self.stages.append(
StackedShiftBlocks(
epsilon=epsilon,
mlp_dropout_rate=mlp_dropout_rate,
num_shift_blocks=num_shift_blocks,
stochastic_depth_rate=stochastic_depth_rate,
is_merge=is_merge,
num_div=num_div,
shift_pixel=shift_pixel,
mlp_expand_ratio=mlp_expand_ratio,
)
)
self.global_avg_pool = layers.GlobalAveragePooling2D()
self.classifier = layers.Dense(config.num_classes)
def get_config(self):
config = super().get_config()
config.update(
{
"data_augmentation": self.data_augmentation,
"patch_projection": self.patch_projection,
"stages": self.stages,
"global_avg_pool": self.global_avg_pool,
"classifier": self.classifier,
}
)
return config
def _calculate_loss(self, data, training=False):
(images, labels) = data
# Augment the images
augmented_images = self.data_augmentation(images, training=training)
# Create patches and project the patches.
projected_patches = self.patch_projection(augmented_images)
# Pass through the stages
x = projected_patches
for stage in self.stages:
x = stage(x, training=training)
# Get the logits.
x = self.global_avg_pool(x)
logits = self.classifier(x)
# Calculate the loss and return it.
total_loss = self.compiled_loss(labels, logits)
return total_loss, labels, logits
def train_step(self, inputs):
with tf.GradientTape() as tape:
total_loss, labels, logits = self._calculate_loss(
data=inputs, training=True
)
# Apply gradients.
train_vars = [
self.data_augmentation.trainable_variables,
self.patch_projection.trainable_variables,
self.global_avg_pool.trainable_variables,
self.classifier.trainable_variables,
]
train_vars = train_vars + [stage.trainable_variables for stage in self.stages]
# Optimize the gradients.
grads = tape.gradient(total_loss, train_vars)
trainable_variable_list = []
for (grad, var) in zip(grads, train_vars):
for g, v in zip(grad, var):
trainable_variable_list.append((g, v))
self.optimizer.apply_gradients(trainable_variable_list)
# Update the metrics
self.compiled_metrics.update_state(labels, logits)
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
_, labels, logits = self._calculate_loss(data=data, training=False)
# Update the metrics
self.compiled_metrics.update_state(labels, logits)
return {m.name: m.result() for m in self.metrics}
def call(self, images):
augmented_images = self.data_augmentation(images)
x = self.patch_projection(augmented_images)
for stage in self.stages:
x = stage(x, training=False)
x = self.global_avg_pool(x)
logits = self.classifier(x)
return logits
model = ShiftViTModel(
data_augmentation=get_augmentation_model(),
projected_dim=config.projected_dim,
patch_size=config.patch_size,
num_shift_blocks_per_stages=config.num_shift_blocks_per_stages,
epsilon=config.epsilon,
mlp_dropout_rate=config.mlp_dropout_rate,
stochastic_depth_rate=config.stochastic_depth_rate,
num_div=config.num_div,
shift_pixel=config.shift_pixel,
mlp_expand_ratio=config.mlp_expand_ratio,
)
In many experiments, we want to warm up the model with a slowly increasing learning rate and then cool it down with a slowly decaying learning rate. In the warmup cosine decay schedule, the learning rate increases linearly during the warmup steps and then decays following a cosine curve.
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
"""一个使用预热余弦衰减调度的学习率调度类。"""
def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
"""
Args:
lr_start: The initial learning rate
lr_max: The maximum learning rate to which lr should increase during the warmup steps
warmup_steps: The number of steps for which the model warms up
total_steps: The total number of steps for the model training
"""
super().__init__()
self.lr_start = lr_start
self.lr_max = lr_max
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.pi = tf.constant(np.pi)
def __call__(self, step):
# Check whether the total number of steps is larger than the warmup steps. If not, then throw a value error.
if self.total_steps < self.warmup_steps:
raise ValueError(
f"总步骤数 {self.total_steps} 必须"
+ f"大于或等于预热步骤数 {self.warmup_steps}。"
)
# `cos_annealed_lr` is a graph that increases to 1 from the initial step to the warmup step. After that this graph decays to -1 at the final step mark.
cos_annealed_lr = tf.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
/ tf.cast(self.total_steps - self.warmup_steps, tf.float32)
)
# Shift the mean of the `cos_annealed_lr` graph to 1. Now the graph goes from 0 to 2. Normalize the graph with 0.5 so that it goes from 0 to 1. Scale the normalized graph with `lr_max` such that it goes from 0 to `lr_max`.
learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)
# Check whether warmup_steps is more than 0.
if self.warmup_steps > 0:
# Check whether lr_max is larger than lr_start. If not, throw a value error.
if self.lr_max < self.lr_start:
raise ValueError(
f"lr_start {self.lr_start} 必须小于或"
+ f"等于 lr_max {self.lr_max}。"
)
# Calculate the slope with which the learning rate should increase in the warmup schedule. The formula for the slope is m = ((b-a)/steps).
slope = (self.lr_max - self.lr_start) / self.warmup_steps
# Build the warmup schedule with the formula for a straight line (y = mx+c).
warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start
# When the current step is less than the warmup steps, use the line graph. When the current step is greater than the warmup steps, use the scaled cosine graph.
learning_rate = tf.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
# When the current step is more than the total steps, return 0, else return the calculated graph.
return tf.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
def get_config(self):
config = {
"lr_start": self.lr_start,
"lr_max": self.lr_max,
"total_steps": self.total_steps,
"warmup_steps": self.warmup_steps,
}
return config
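As an illustrative sanity check of the schedule (the step counts below are arbitrary and not the ones used for training), querying the learning rate at a few steps shows the linear warmup from lr_start to lr_max followed by the cosine decay towards zero at total_steps:
# Illustrative check of the warmup cosine schedule (arbitrary step counts):
demo_schedule = WarmUpCosine(lr_start=1e-5, lr_max=1e-3, warmup_steps=100, total_steps=1000)
for step in [0, 50, 100, 550, 1000]:
    print(step, float(demo_schedule(step)))
# The rate climbs to ~1e-3 by step 100, then decays along a cosine to 0 at step 1000.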
# pass sample data to the model so that input shape is available at the time of
# saving the model
sample_ds, _ = next(iter(train_ds))
model(sample_ds, training=False)
# Get the total number of steps for training.
total_steps = int((len(x_train) / config.batch_size) * config.epochs)
# Calculate the number of steps for warmup.
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
# Initialize the warmupcosine schedule.
scheduled_lrs = WarmUpCosine(
lr_start=1e-5,
lr_max=1e-3,
warmup_steps=warmup_steps,
total_steps=total_steps,
)
# Get the optimizer.
optimizer = tfa.optimizers.AdamW(
learning_rate=scheduled_lrs, weight_decay=config.weight_decay
)
# Compile and pretrain the model.
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Train the model
history = model.fit(
train_ds,
epochs=config.epochs,
validation_data=val_ds,
callbacks=[
keras.callbacks.EarlyStopping(
monitor="val_accuracy",
patience=5,
mode="auto",
)
],
)
# Evaluate the model with the test dataset.
print("TESTING")
loss, acc_top1, acc_top5 = model.evaluate(test_ds)
print(f"Loss: {loss:0.2f}")
print(f"Top 1 test accuracy: {acc_top1*100:0.2f}%")
print(f"Top 5 test accuracy: {acc_top5*100:0.2f}%")
Epoch 1/100
157/157 [==============================] - 72s 332ms/step - loss: 2.3844 - accuracy: 0.1444 - top-5-accuracy: 0.6051 - val_loss: 2.0984 - val_accuracy: 0.2610 - val_top-5-accuracy: 0.7638
Epoch 2/100
157/157 [==============================] - 49s 314ms/step - loss: 1.9457 - accuracy: 0.2893 - top-5-accuracy: 0.8103 - val_loss: 1.9459 - val_accuracy: 0.3356 - val_top-5-accuracy: 0.8614
Epoch 3/100
157/157 [==============================] - 50s 316ms/step - loss: 1.7093 - accuracy: 0.3810 - top-5-accuracy: 0.8761 - val_loss: 1.5349 - val_accuracy: 0.4585 - val_top-5-accuracy: 0.9045
Epoch 4/100
157/157 [==============================] - 49s 315ms/step - loss: 1.5473 - accuracy: 0.4374 - top-5-accuracy: 0.9090 - val_loss: 1.4257 - val_accuracy: 0.4862 - val_top-5-accuracy: 0.9298
Epoch 5/100
157/157 [==============================] - 50s 316ms/step - loss: 1.4316 - accuracy: 0.4816 - top-5-accuracy: 0.9243 - val_loss: 1.4032 - val_accuracy: 0.5092 - val_top-5-accuracy: 0.9362
Epoch 6/100
157/157 [==============================] - 50s 316ms/step - loss: 1.3588 - accuracy: 0.5131 - top-5-accuracy: 0.9333 - val_loss: 1.2893 - val_accuracy: 0.5411 - val_top-5-accuracy: 0.9457
Epoch 7/100
157/157 [==============================] - 50s 316ms/step - loss: 1.2894 - accuracy: 0.5385 - top-5-accuracy: 0.9410 - val_loss: 1.2922 - val_accuracy: 0.5416 - val_top-5-accuracy: 0.9432
Epoch 8/100
157/157 [==============================] - 49s 315ms/step - loss: 1.2388 - accuracy: 0.5568 - top-5-accuracy: 0.9468 - val_loss: 1.2100 - val_accuracy: 0.5733 - val_top-5-accuracy: 0.9545
Epoch 9/100
157/157 [==============================] - 49s 315ms/step - loss: 1.2043 - accuracy: 0.5698 - top-5-accuracy: 0.9491 - val_loss: 1.2166 - val_accuracy: 0.5675 - val_top-5-accuracy: 0.9520
Epoch 10/100
157/157 [==============================] - 49s 315ms/step - loss: 1.1694 - accuracy: 0.5861 - top-5-accuracy: 0.9528 - val_loss: 1.1738 - val_accuracy: 0.5883 - val_top-5-accuracy: 0.9541
Epoch 11/100
157/157 [==============================] - 50s 316ms/step - loss: 1.1290 - accuracy: 0.5994 - top-5-accuracy: 0.9575 - val_loss: 1.1161 - val_accuracy: 0.6064 - val_top-5-accuracy: 0.9618
Epoch 12/100
157/157 [==============================] - 50s 316ms/step - loss: 1.0861 - accuracy: 0.6157 - top-5-accuracy: 0.9602 - val_loss: 1.1220 - val_accuracy: 0.6133 - val_top-5-accuracy: 0.9576
Epoch 13/100
157/157 [==============================] - 49s 315ms/step - loss: 1.0766 - accuracy: 0.6178 - top-5-accuracy: 0.9612 - val_loss: 1.0108 - val_accuracy: 0.6402 - val_top-5-accuracy: 0.9681
Epoch 14/100
157/157 [==============================] - 49s 315ms/step - loss: 1.0179 - accuracy: 0.6416 - top-5-accuracy: 0.9658 - val_loss: 1.0196 - val_accuracy: 0.6405 - val_top-5-accuracy: 0.9667
Epoch 15/100
157/157 [==============================] - 50s 316ms/step - loss: 1.0028 - accuracy: 0.6470 - top-5-accuracy: 0.9678 - val_loss: 1.0113 - val_accuracy: 0.6415 - val_top-5-accuracy: 0.9672
Epoch 16/100
157/157 [==============================] - 50s 316ms/step - loss: 0.9613 - accuracy: 0.6611 - top-5-accuracy: 0.9710 - val_loss: 1.0516 - val_accuracy: 0.6406 - val_top-5-accuracy: 0.9596
Epoch 17/100
157/157 [==============================] - 50s 316ms/step - loss: 0.9262 - accuracy: 0.6740 - top-5-accuracy: 0.9729 - val_loss: 0.9010 - val_accuracy: 0.6844 - val_top-5-accuracy: 0.9750
Epoch 18/100
157/157 [==============================] - 50s 316ms/step - loss: 0.8768 - accuracy: 0.6916 - top-5-accuracy: 0.9769 - val_loss: 0.8862 - val_accuracy: 0.6908 - val_top-5-accuracy: 0.9767
Epoch 19/100
157/157 [==============================] - 49s 315ms/step - loss: 0.8595 - accuracy: 0.6984 - top-5-accuracy: 0.9768 - val_loss: 0.8732 - val_accuracy: 0.6982 - val_top-5-accuracy: 0.9738
Epoch 20/100
157/157 [==============================] - 50s 317ms/step - loss: 0.8252 - accuracy: 0.7103 - top-5-accuracy: 0.9793 - val_loss: 0.9330 - val_accuracy: 0.6745 - val_top-5-accuracy: 0.9718
Epoch 21/100
157/157 [==============================] - 51s 322ms/step - loss: 0.8003 - accuracy: 0.7180 - top-5-accuracy: 0.9814 - val_loss: 0.8912 - val_accuracy: 0.6948 - val_top-5-accuracy: 0.9728
Epoch 22/100
157/157 [==============================] - 51s 326ms/step - loss: 0.7651 - accuracy: 0.7317 - top-5-accuracy: 0.9829 - val_loss: 0.7894 - val_accuracy: 0.7277 - val_top-5-accuracy: 0.9791
Epoch 23/100
157/157 [==============================] - 52s 328ms/step - loss: 0.7372 - accuracy: 0.7415 - top-5-accuracy: 0.9843 - val_loss: 0.7752 - val_accuracy: 0.7284 - val_top-5-accuracy: 0.9804
Epoch 24/100
157/157 [==============================] - 51s 327ms/step - loss: 0.7324 - accuracy: 0.7423 - top-5-accuracy: 0.9852 - val_loss: 0.7949 - val_accuracy: 0.7340 - val_top-5-accuracy: 0.9792
Epoch 25/100
157/157 [==============================] - 51s 323ms/step - loss: 0.7051 - accuracy: 0.7512 - top-5-accuracy: 0.9858 - val_loss: 0.7967 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9787
Epoch 26/100
157/157 [==============================] - 51s 323ms/step - loss: 0.6832 - accuracy: 0.7577 - top-5-accuracy: 0.9870 - val_loss: 0.7840 - val_accuracy: 0.7322 - val_top-5-accuracy: 0.9807
Epoch 27/100
157/157 [==============================] - 51s 322ms/step - loss: 0.6609 - accuracy: 0.7654 - top-5-accuracy: 0.9877 - val_loss: 0.7447 - val_accuracy: 0.7434 - val_top-5-accuracy: 0.9816
Epoch 28/100
157/157 [==============================] - 50s 319ms/step - loss: 0.6495 - accuracy: 0.7724 - top-5-accuracy: 0.9883 - val_loss: 0.7885 - val_accuracy: 0.7280 - val_top-5-accuracy: 0.9817
Epoch 29/100
157/157 [==============================] - 50s 317ms/step - loss: 0.6491 - accuracy: 0.7707 - top-5-accuracy: 0.9885 - val_loss: 0.7539 - val_accuracy: 0.7458 - val_top-5-accuracy: 0.9821
Epoch 30/100
157/157 [==============================] - 50s 317ms/step - loss: 0.6213 - accuracy: 0.7823 - top-5-accuracy: 0.9888 - val_loss: 0.7571 - val_accuracy: 0.7470 - val_top-5-accuracy: 0.9815
Epoch 31/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5976 - accuracy: 0.7902 - top-5-accuracy: 0.9906 - val_loss: 0.7430 - val_accuracy: 0.7508 - val_top-5-accuracy: 0.9817
Epoch 32/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5932 - accuracy: 0.7898 - top-5-accuracy: 0.9910 - val_loss: 0.7545 - val_accuracy: 0.7469 - val_top-5-accuracy: 0.9793
Epoch 33/100
157/157 [==============================] - 50s 318ms/step - loss: 0.5977 - accuracy: 0.7850 - top-5-accuracy: 0.9913 - val_loss: 0.7200 - val_accuracy: 0.7569 - val_top-5-accuracy: 0.9830
Epoch 34/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5552 - accuracy: 0.8041 - top-5-accuracy: 0.9920 - val_loss: 0.7377 - val_accuracy: 0.7552 - val_top-5-accuracy: 0.9818
Epoch 35/100
157/157 [==============================] - 50s 319ms/step - loss: 0.5509 - accuracy: 0.8056 - top-5-accuracy: 0.9921 - val_loss: 0.8125 - val_accuracy: 0.7331 - val_top-5-accuracy: 0.9782
Epoch 36/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5296 - accuracy: 0.8116 - top-5-accuracy: 0.9933 - val_loss: 0.6900 - val_accuracy: 0.7680 - val_top-5-accuracy: 0.9849
Epoch 37/100
157/157 [==============================] - 50s 316ms/step - loss: 0.5151 - accuracy: 0.8170 - top-5-accuracy: 0.9941 - val_loss: 0.7275 - val_accuracy: 0.7610 - val_top-5-accuracy: 0.9841
Epoch 38/100
157/157 [==============================] - 50s 317ms/step - loss: 0.5069 - accuracy: 0.8217 - top-5-accuracy: 0.9936 - val_loss: 0.7067 - val_accuracy: 0.7703 - val_top-5-accuracy: 0.9835
Epoch 39/100
157/157 [==============================] - 50s 318ms/step - loss: 0.4771 - accuracy: 0.8304 - top-5-accuracy: 0.9945 - val_loss: 0.7110 - val_accuracy: 0.7668 - val_top-5-accuracy: 0.9836
Epoch 40/100
157/157 [==============================] - 50s 317ms/step - loss: 0.4675 - accuracy: 0.8350 - top-5-accuracy: 0.9956 - val_loss: 0.7130 - val_accuracy: 0.7688 - val_top-5-accuracy: 0.9829
Epoch 41/100
157/157 [==============================] - 50s 319ms/step - loss: 0.4586 - accuracy: 0.8382 - top-5-accuracy: 0.9959 - val_loss: 0.7331 - val_accuracy: 0.7598 - val_top-5-accuracy: 0.9806
Epoch 42/100
157/157 [==============================] - 50s 318ms/step - loss: 0.4558 - accuracy: 0.8380 - top-5-accuracy: 0.9959 - val_loss: 0.7187 - val_accuracy: 0.7722 - val_top-5-accuracy: 0.9832
Epoch 43/100
157/157 [==============================] - 50s 320ms/step - loss: 0.4356 - accuracy: 0.8450 - top-5-accuracy: 0.9958 - val_loss: 0.7162 - val_accuracy: 0.7693 - val_top-5-accuracy: 0.9850
Epoch 44/100
157/157 [==============================] - 49s 314ms/step - loss: 0.4425 - accuracy: 0.8433 - top-5-accuracy: 0.9958 - val_loss: 0.7061 - val_accuracy: 0.7698 - val_top-5-accuracy: 0.9853
Epoch 45/100
157/157 [==============================] - 49s 314ms/step - loss: 0.4072 - accuracy: 0.8551 - top-5-accuracy: 0.9967 - val_loss: 0.7025 - val_accuracy: 0.7820 - val_top-5-accuracy: 0.9848
Epoch 46/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3865 - accuracy: 0.8644 - top-5-accuracy: 0.9970 - val_loss: 0.7178 - val_accuracy: 0.7740 - val_top-5-accuracy: 0.9844
Epoch 47/100
157/157 [==============================] - 49s 313ms/step - loss: 0.3718 - accuracy: 0.8694 - top-5-accuracy: 0.9973 - val_loss: 0.7216 - val_accuracy: 0.7768 - val_top-5-accuracy: 0.9828
Epoch 48/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3733 - accuracy: 0.8673 - top-5-accuracy: 0.9970 - val_loss: 0.7440 - val_accuracy: 0.7713 - val_top-5-accuracy: 0.9841
Epoch 49/100
157/157 [==============================] - 49s 313ms/step - loss: 0.3531 - accuracy: 0.8741 - top-5-accuracy: 0.9979 - val_loss: 0.7220 - val_accuracy: 0.7738 - val_top-5-accuracy: 0.9848
Epoch 50/100
157/157 [==============================] - 49s 314ms/step - loss: 0.3502 - accuracy: 0.8738 - top-5-accuracy: 0.9980 - val_loss: 0.7245 - val_accuracy: 0.7734 - val_top-5-accuracy: 0.9836
TESTING
40/40 [==============================] - 2s 56ms/step - loss: 0.7336 - accuracy: 0.7638 - top-5-accuracy: 0.9855
Loss: 0.73
Top 1 test accuracy: 76.38%
Top 5 test accuracy: 98.55%
Since we created the model by sub-classing, we cannot save it in the HDF5 format.
It can only be saved in the TF SavedModel format. In general, this is the recommended format for saving models as well.
model.save("ShiftViT")
Download sample data for inference
!wget -q 'https://tinyurl.com/2p9483sw' -O inference_set.zip
!unzip -q inference_set.zip
Load saved model
# Custom objects are not included when the model is saved.
# At loading time, these objects need to be passed for reconstruction of the model
saved_model = tf.keras.models.load_model(
"ShiftViT",
custom_objects={"WarmUpCosine": WarmUpCosine, "AdamW": tfa.optimizers.AdamW},
)
Utility functions for inference
def process_image(img_path):
# read image file from string path
img = tf.io.read_file(img_path)
# decode jpeg into a uint8 tensor
img = tf.io.decode_jpeg(img, channels=3)
# resize image to match the input size accepted by the model
# use `method` as `nearest` to preserve the dtype of the input passed to `resize()`
img = tf.image.resize(
img, [config.input_shape[0], config.input_shape[1]], method="nearest"
)
return img
def create_tf_dataset(image_dir):
data_dir = pathlib.Path(image_dir)
# create a tf.data dataset from the image directory
predict_ds = tf.data.Dataset.list_files(str(data_dir / "*.jpg"), shuffle=False)
# use map to convert the string paths to uint8 image tensors
# setting `num_parallel_calls` helps process multiple images in parallel
predict_ds = predict_ds.map(process_image, num_parallel_calls=AUTO)
# create a prefetched dataset for better latency and throughput
predict_ds = predict_ds.batch(config.tf_ds_batch_size).prefetch(AUTO)
return predict_ds
def predict(predict_ds):
# the ShiftViT model returns logits (non-normalized predictions)
logits = saved_model.predict(predict_ds)
# normalize the predictions by calling softmax()
probabilities = tf.nn.softmax(logits)
return probabilities
def get_predicted_class(probabilities):
pred_label = np.argmax(probabilities)
predicted_class = config.label_map[pred_label]
return predicted_class
def get_confidence_scores(probabilities):
# get the indices of the probability scores sorted in descending order
labels = np.argsort(probabilities)[::-1]
confidences = {
config.label_map[label]: np.round((probabilities[label]) * 100, 2)
for label in labels
}
return confidences
Get predictions
img_dir = "inference_set"
predict_ds = create_tf_dataset(img_dir)
probabilities = predict(predict_ds)
print(f"probabilities: {probabilities[0]}")
confidences = get_confidence_scores(probabilities[0])
print(confidences)
1/1 [==============================] - 2s 2s/step
probabilities: [8.7329084e-01 1.3162658e-03 6.1781306e-05 1.9132349e-05 4.4482469e-05
1.8182898e-06 2.2834571e-05 1.1466043e-05 1.2504059e-01 1.9084632e-04]
{'airplane': 87.33, 'ship': 12.5, 'automobile': 0.13, 'truck': 0.02, 'bird': 0.01, 'deer': 0.0, 'frog': 0.0, 'cat': 0.0, 'horse': 0.0, 'dog': 0.0}
View predictions
plt.figure(figsize=(10, 10))
for images in predict_ds:
for i in range(min(6, probabilities.shape[0])):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
predicted_class = get_predicted_class(probabilities[i])
plt.title(predicted_class)
plt.axis("off")
The most impactful contribution of the paper is not the novel architecture, but the idea that hierarchical ViTs trained with no attention can perform quite well. This opens up the question of how essential attention is to the performance of ViTs.
For the curious mind, we would suggest reading the ConvNexT paper, which focuses more on the training paradigms and architectural details of ViTs than on providing a novel attention-based architecture.
Acknowledgements:
Example available on HuggingFace:
| Trained Model | Demo |
| --- | --- |