代码示例 / 计算机视觉 / 图像分类使用Swin Transformers

图像分类使用Swin Transformers

作者: Rishit Dagli
创建日期: 2021/09/08
最后修改: 2021/09/08
描述: 使用Swin Transformers进行图像分类,Swin Transformers是计算机视觉的通用骨干网络。

在Colab中查看 GitHub源代码

这个示例实现了 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 由Liu等人提出的图像分类,并在 CIFAR-100数据集上进行了演示。

Swin Transformer(Shifted Window Transformer)可以作为计算机视觉的通用骨干网络。Swin Transformer是一个分层Transformer,其表示是通过_移动窗口_计算的。移动窗口方案通过限制自注意力计算在不重叠的局部窗口内,同时还允许跨窗口连接,从而带来了更高的效率。这种架构具有在不同尺度上建模信息的灵活性,并且与图像大小具有线性计算复杂度。

此示例要求使用TensorFlow 2.5或更高版本。


设置

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # 仅用于tf.data和预处理。
import keras
from keras import layers
from keras import ops

配置超参数

一个关键的参数是patch_size,即输入块的大小。 为了将每个像素作为单独的输入,可以将patch_size设置为 (1, 1)。下面,我们从原始论文在ImageNet-1K上训练设置中获取灵感,保持大部分原始设置以用于本示例。

num_classes = 100
input_shape = (32, 32, 3)

patch_size = (2, 2)  # 2x2大小的块
dropout_rate = 0.03  # 丢弃率
num_heads = 8  # 注意力头
embed_dim = 64  # 嵌入维度
num_mlp = 256  # MLP层大小
# 将嵌入的块转换为带有可学习加法值的查询、键和值
qkv_bias = True
window_size = 2  # 注意力窗口大小
shift_size = 1  # 移动窗口的大小
image_dimension = 32  # 初始图像大小

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

learning_rate = 1e-3
batch_size = 128
num_epochs = 40
validation_split = 0.1
weight_decay = 0.0001
label_smoothing = 0.1

准备数据

我们通过keras.datasets加载CIFAR-100数据集, 对图像进行归一化,并将整数标签转换为独热编码向量。

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
num_train_samples = int(len(x_train) * (1 - validation_split))
num_val_samples = len(x_train) - num_train_samples
x_train, x_val = np.split(x_train, [num_train_samples])
y_train, y_val = np.split(y_train, [num_train_samples])
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i])
plt.show()
x_train shape: (45000, 32, 32, 3) - y_train shape: (45000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)

png


辅助函数

我们创建了两个辅助函数,帮助我们从图像中获取一系列块,合并块,并应用丢弃。

def window_partition(x, window_size):
    _, height, width, channels = x.shape
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = ops.reshape(
        x,
        (
            -1,
            patch_num_y,
            window_size,
            patch_num_x,
            window_size,
            channels,
        ),
    )
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
    windows = ops.reshape(x, (-1, window_size, window_size, channels))
    return windows


def window_reverse(windows, window_size, height, width, channels):
    patch_num_y = height // window_size
    patch_num_x = width // window_size
    x = ops.reshape(
        windows,
        (
            -1,
            patch_num_y,
            patch_num_x,
            window_size,
            window_size,
            channels,
        ),
    )
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
    x = ops.reshape(x, (-1, height, width, channels))
    return x

基于窗口的多头自注意力

通常,Transformer执行全局自注意力,其中计算一个token与所有其他tokens之间的关系。全局计算导致相对于token数量的二次复杂性。在这里,正如原始论文所建议的,我们在局部窗口内以非重叠方式计算自注意力。全局自注意力导致与patch数量相关的二次计算复杂性,而基于窗口的自注意力则导致线性复杂性,并且易于扩展。

class WindowAttention(layers.Layer):
    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.dropout = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)

        num_window_elements = (2 * self.window_size[0] - 1) * (
            2 * self.window_size[1] - 1
        )
        self.relative_position_bias_table = self.add_weight(
            shape=(num_window_elements, self.num_heads),
            initializer=keras.initializers.Zeros(),
            trainable=True,
        )
        coords_h = np.arange(self.window_size[0])
        coords_w = np.arange(self.window_size[1])
        coords_matrix = np.meshgrid(coords_h, coords_w, indexing="ij")
        coords = np.stack(coords_matrix)
        coords_flatten = coords.reshape(2, -1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)

        self.relative_position_index = keras.Variable(
            initializer=relative_position_index,
            shape=relative_position_index.shape,
            dtype="int",
            trainable=False,
        )

    def call(self, x, mask=None):
        _, size, channels = x.shape
        head_dim = channels // self.num_heads
        x_qkv = self.qkv(x)
        x_qkv = ops.reshape(x_qkv, (-1, size, 3, self.num_heads, head_dim))
        x_qkv = ops.transpose(x_qkv, (2, 0, 3, 1, 4))
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]
        q = q * self.scale
        k = ops.transpose(k, (0, 1, 3, 2))
        attn = q @ k

        num_window_elements = self.window_size[0] * self.window_size[1]
        relative_position_index_flat = ops.reshape(self.relative_position_index, (-1,))
        relative_position_bias = ops.take(
            self.relative_position_bias_table,
            relative_position_index_flat,
            axis=0,
        )
        relative_position_bias = ops.reshape(
            relative_position_bias,
            (num_window_elements, num_window_elements, -1),
        )
        relative_position_bias = ops.transpose(relative_position_bias, (2, 0, 1))
        attn = attn + ops.expand_dims(relative_position_bias, axis=0)

        if mask is not None:
            nW = mask.shape[0]
            mask_float = ops.cast(
                ops.expand_dims(ops.expand_dims(mask, axis=1), axis=0),
                "float32",
            )
            attn = ops.reshape(attn, (-1, nW, self.num_heads, size, size)) + mask_float
            attn = ops.reshape(attn, (-1, self.num_heads, size, size))
            attn = keras.activations.softmax(attn, axis=-1)
        else:
            attn = keras.activations.softmax(attn, axis=-1)
        attn = self.dropout(attn)

        x_qkv = attn @ v
        x_qkv = ops.transpose(x_qkv, (0, 2, 1, 3))
        x_qkv = ops.reshape(x_qkv, (-1, size, channels))
        x_qkv = self.proj(x_qkv)
        x_qkv = self.dropout(x_qkv)
        return x_qkv

完整的Swin Transformer模型

最后,通过用移动窗口注意力替换标准的多头注意力(MHA),我们将完整的Swin Transformer结合在一起。正如原始论文所建议的,我们创建一个模型,其中包含一个基于移动窗口的MHA层,后面跟着一个中间带GELU非线性的2层MLP,在每个MSA层和每个MLP之前应用LayerNormalization,并在这些层之后添加残差连接。

注意,我们只创建一个简单的具有2个Dense层和2个Dropout层的MLP。通常你会看到模型使用ResNet-50作为MLP,这在文献中相当标准。然而在本文中,作者使用了一个带GELU非线性的2层MLP。

class SwinTransformer(layers.Layer):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.dim = dim  # 输入维度的数量
        self.num_patch = num_patch  # 嵌入补丁的数量
        self.num_heads = num_heads  # 注意力头的数量
        self.window_size = window_size  # 窗口大小
        self.shift_size = shift_size  # 窗口位移大小
        self.num_mlp = num_mlp  # MLP节点的数量

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = layers.Dropout(dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)

        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )

        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

    def build(self, input_shape):
        if self.shift_size == 0:
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = ops.convert_to_tensor(mask_array)

            # 将掩码数组转换为窗口
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = ops.reshape(
                mask_windows, [-1, self.window_size * self.window_size]
            )
            attn_mask = ops.expand_dims(mask_windows, axis=1) - ops.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = ops.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = ops.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = keras.Variable(
                initializer=attn_mask,
                shape=attn_mask.shape,
                dtype=attn_mask.dtype,
                trainable=False,
            )

    def call(self, x, training=False):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = ops.reshape(x, (-1, height, width, channels))
        if self.shift_size > 0:
            shifted_x = ops.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = ops.reshape(
            x_windows, (-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = ops.reshape(
            attn_windows,
            (-1, self.window_size, self.window_size, channels),
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = ops.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        x = ops.reshape(x, (-1, height * width, channels))
        x = self.drop_path(x, training=training)
        x = x_skip + x
        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x_skip + x
        return x

模型训练与评估

提取和嵌入补丁

我们首先创建3层来帮助我们从图像中提取、嵌入和合并补丁,之后将使用我们构建的Swin Transformer类。

# 使用tf操作,因为它只在tf.data中使用。
def patch_extract(images):
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=(1, patch_size[0], patch_size[1], 1),
        strides=(1, patch_size[0], patch_size[1], 1),
        rates=(1, 1, 1, 1),
        padding="VALID",
    )
    patch_dim = patches.shape[-1]
    patch_num = patches.shape[1]
    return tf.reshape(patches, (batch_size, patch_num * patch_num, patch_dim))


class PatchEmbedding(layers.Layer):
    def __init__(self, num_patch, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patch = num_patch
        self.proj = layers.Dense(embed_dim)
        self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

    def call(self, patch):
        pos = ops.arange(start=0, stop=self.num_patch)
        return self.proj(patch) + self.pos_embed(pos)


class PatchMerging(keras.layers.Layer):
    def __init__(self, num_patch, embed_dim):
        super().__init__()
        self.num_patch = num_patch
        self.embed_dim = embed_dim
        self.linear_trans = layers.Dense(2 * embed_dim, use_bias=False)

    def call(self, x):
        height, width = self.num_patch
        _, _, C = x.shape
        x = ops.reshape(x, (-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = ops.concatenate((x0, x1, x2, x3), axis=-1)
        x = ops.reshape(x, (-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

准备tf.data.Dataset

我们使用tf.data完成所有没有可训练权重的步骤。准备训练、验证和测试集。

def augment(x):
    x = tf.image.random_crop(x, size=(image_dimension, image_dimension, 3))
    x = tf.image.random_flip_left_right(x)
    return x


dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .map(lambda x, y: (augment(x), y))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset_val = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset_test = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size=batch_size)
    .map(lambda x, y: (patch_extract(x), y))
    .prefetch(tf.data.experimental.AUTOTUNE)
)

构建模型

我们将Swin Transformer模型结合在一起。

input = layers.Input(shape=(256, 12))
x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(input)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=0,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = SwinTransformer(
    dim=embed_dim,
    num_patch=(num_patch_x, num_patch_y),
    num_heads=num_heads,
    window_size=window_size,
    shift_size=shift_size,
    num_mlp=num_mlp,
    qkv_bias=qkv_bias,
    dropout_rate=dropout_rate,
)(x)
x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(num_classes, activation="softmax")(x)
这台机器上可能存在NVIDIA GPU,但未安装支持CUDA的jaxlib。将回退到cpu。

在CIFAR-100上训练

我们在CIFAR-100上训练模型。在这个例子中,我们仅训练模型40个周期,以保持训练时间短。实际上,您应该训练150个周期以达到收敛。

model = keras.Model(input, output)
model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name="accuracy"),
        keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)

history = model.fit(
    dataset,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_data=dataset_val,
)
Epoch 1/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 644s 2s/step - 准确率: 0.0517 - 损失: 4.3948 - 前五名准确率: 0.1816 - 验证准确率: 0.1396 - 验证损失: 3.7930 - 验证前五名准确率: 0.3922
Epoch 2/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 626s 2s/step - 准确率: 0.1606 - 损失: 3.7267 - 前五名准确率: 0.4209 - 验证准确率: 0.1946 - 验证损失: 3.5560 - 验证前五名准确率: 0.4862
Epoch 3/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.2160 - 损失: 3.4910 - 前五名准确率: 0.5076 - 验证准确率: 0.2440 - 验证损失: 3.3946 - 验证前五名准确率: 0.5384
Epoch 4/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 620s 2s/step - 准确率: 0.2599 - 损失: 3.3266 - 前五名准确率: 0.5628 - 验证准确率: 0.2730 - 验证损失: 3.2732 - 验证前五名准确率: 0.5812
Epoch 5/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.2841 - 损失: 3.2082 - 前五名准确率: 0.5988 - 验证准确率: 0.2878 - 验证损失: 3.1837 - 验证前五名准确率: 0.6050
Epoch 6/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - 准确率: 0.3049 - 损失: 3.1199 - 前五名准确率: 0.6262 - 验证准确率: 0.3110 - 验证损失: 3.0970 - 验证前五名准确率: 0.6292
Epoch 7/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 620s 2s/step - 准确率: 0.3271 - 损失: 3.0387 - 前五名准确率: 0.6501 - 验证准确率: 0.3292 - 验证损失: 3.0374 - 验证前五名准确率: 0.6488
Epoch 8/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - 准确率: 0.3454 - 损失: 2.9764 - 前五名准确率: 0.6679 - 验证准确率: 0.3480 - 验证损失: 2.9921 - 验证前五名准确率: 0.6598
Epoch 9/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - 准确率: 0.3571 - 损失: 2.9272 - 前五名准确率: 0.6801 - 验证准确率: 0.3522 - 验证损失: 2.9585 - 验证前五名准确率: 0.6746
Epoch 10/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 624s 2s/step - 准确率: 0.3658 - 损失: 2.8809 - 前五名准确率: 0.6924 - 验证准确率: 0.3562 - 验证损失: 2.9364 - 验证前五名准确率: 0.6784
Epoch 11/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.3796 - 损失: 2.8425 - 前五名准确率: 0.7021 - 验证准确率: 0.3654 - 验证损失: 2.9100 - 验证前五名准确率: 0.6832
Epoch 12/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 622s 2s/step - 准确率: 0.3884 - 损失: 2.8113 - 前五名准确率: 0.7103 - 验证准确率: 0.3740 - 验证损失: 2.8808 - 验证前五名准确率: 0.6948
Epoch 13/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 621s 2s/step - 准确率: 0.3994 - 损失: 2.7718 - 前五名准确率: 0.7239 - 验证准确率: 0.3778 - 验证损失: 2.8637 - 验证前五名准确率: 0.6994
Epoch 14/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.4072 - 损失: 2.7491 - 前五名准确率: 0.7271 - 验证准确率: 0.3848 - 验证损失: 2.8533 - 验证前五名准确率: 0.7002
Epoch 15/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 614s 2s/step - 准确率: 0.4142 - 损失: 2.7180 - 前五名准确率: 0.7344 - 验证准确率: 0.3880 - 验证损失: 2.8383 - 验证前五名准确率: 0.7080
Epoch 16/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 614s 2s/step - 准确率: 0.4231 - 损失: 2.6918 - 前五名准确率: 0.7392 - 验证准确率: 0.3934 - 验证损失: 2.8323 - 验证前五名准确率: 0.7072
Epoch 17/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - 准确率: 0.4339 - 损失: 2.6633 - 前五名准确率: 0.7484 - 验证准确率: 0.3972 - 验证损失: 2.8237 - 验证前五名准确率: 0.7138
Epoch 18/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - 准确率: 0.4388 - 损失: 2.6436 - 前五名准确率: 0.7506 - 验证准确率: 0.3984 - 验证损失: 2.8119 - 验证前五名准确率: 0.7144
Epoch 19/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - 准确率: 0.4439 - 损失: 2.6251 - 前五名准确率: 0.7552 - 验证准确率: 0.4020 - 验证损失: 2.8044 - 验证前五名准确率: 0.7178
Epoch 20/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 611s 2s/step - 准确率: 0.4540 - 损失: 2.5989 - 前五名准确率: 0.7652 - 验证准确率: 0.4012 - 验证损失: 2.7969 - 验证前五名准确率: 0.7246
Epoch 21/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - 准确率: 0.4586 - 损失: 2.5760 - 前五名准确率: 0.7684 - 验证准确率: 0.4092 - 验证损失: 2.7807 - 验证前五名准确率: 0.7254
Epoch 22/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - 准确率: 0.4607 - 损失: 2.5624 - 前五名准确率: 0.7724 - 验证准确率: 0.4158 - 验证损失: 2.7721 - 验证前五名准确率: 0.7232
Epoch 23/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.4658 - 损失: 2.5407 - 前五名准确率: 0.7786 - 验证准确率: 0.4180 - 验证损失: 2.7767 - 验证前五名准确率: 0.7280
Epoch 24/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 617s 2s/step - 准确率: 0.4744 - 损失: 2.5233 - 前五名准确率: 0.7840 - 验证准确率: 0.4164 - 验证损失: 2.7707 - 验证前五名准确率: 0.7300
Epoch 25/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - 准确率: 0.4758 - 损失: 2.5129 - 前五名准确率: 0.7847 - 验证准确率: 0.4196 - 验证损失: 2.7677 - 验证前五名准确率: 0.7294
Epoch 26/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - 准确率: 0.4853 - 损失: 2.4954 - 前五名准确率: 0.7863 - 验证准确率: 0.4188 - 验证损失: 2.7571 - 验证前五名准确率: 0.7362
Epoch 27/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - 准确率: 0.4858 - 损失: 2.4785 - 前五名准确率: 0.7928 - 验证准确率: 0.4186 - 验证损失: 2.7615 - 验证前五名准确率: 0.7348
Epoch 28/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 613s 2s/step - 准确率: 0.4889 - 损失: 2.4691 - 前五名准确率: 0.7945 - 验证准确率: 0.4208 - 验证损失: 2.7561 - 验证前五名准确率: 0.7350
Epoch 29/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.4940 - 损失: 2.4592 - 前五名准确率: 0.7992 - 验证准确率: 0.4244 - 验证损失: 2.7546 - 验证前五名准确率: 0.7398
Epoch 30/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.4989 - 损失: 2.4391 - 前五名准确率: 0.8025 - 验证准确率: 0.4180 - 验证损失: 2.7861 - 验证前五名准确率: 0.7302
Epoch 31/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 610s 2s/step - 准确率: 0.4994 - 损失: 2.4354 - 前五名准确率: 0.8032 - 验证准确率: 0.4264 - 验证损失: 2.7608 - 验证前五名准确率: 0.7394
Epoch 32/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 607s 2s/step - 准确率: 0.5011 - 损失: 2.4238 - 前五名准确率: 0.8090 - 验证准确率: 0.4292 - 验证损失: 2.7625 - 验证前五名准确率: 0.7384
Epoch 33/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.5065 - 损失: 2.4144 - 前五名准确率: 0.8085 - 验证准确率: 0.4288 - 验证损失: 2.7517 - 验证前五名准确率: 0.7328
Epoch 34/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 612s 2s/step - 准确率: 0.5094 - 损失: 2.4099 - 前五名准确率: 0.8093 - 验证准确率: 0.4260 - 验证损失: 2.7550 - 验证前五名准确率: 0.7390
Epoch 35/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 612s 2s/step - 准确率: 0.5109 - 损失: 2.3980 - 前五名准确率: 0.8115 - 验证准确率: 0.4278 - 验证损失: 2.7496 - 验证前五名准确率: 0.7396
Epoch 36/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 615s 2s/step - 准确率: 0.5178 - 损失: 2.3868 - 前五名准确率: 0.8139 - 验证准确率: 0.4296 - 验证损失: 2.7519 - 验证前五名准确率: 0.7404
Epoch 37/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - 准确率: 0.5151 - 损失: 2.3842 - 前五名准确率: 0.8150 - 验证准确率: 0.4308 - 验证损失: 2.7504 - 验证前五名准确率: 0.7424
Epoch 38/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 613s 2s/step - 准确率: 0.5169 - 损失: 2.3798 - 前五名准确率: 0.8159 - 验证准确率: 0.4360 - 验证损失: 2.7522 - 验证前五名准确率: 0.7464
Epoch 39/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 618s 2s/step - 准确率: 0.5228 - 损失: 2.3641 - 前五名准确率: 0.8201 - 验证准确率: 0.4374 - 验证损失: 2.7386 - 验证前五名准确率: 0.7452
Epoch 40/40
 352/352 ━━━━━━━━━━━━━━━━━━━━ 634s 2s/step - 准确率: 0.5232 - 损失: 2.3633 - 前五名准确率: 0.8212 - 验证准确率: 0.4266 - 验证损失: 2.7614 - 验证前五名准确率: 0.7410

让我们可视化模型的训练进度。

plt.plot(history.history["loss"], label="train_loss")  # 训练损失
plt.plot(history.history["val_loss"], label="val_loss")  # 验证损失
plt.xlabel("Epochs")  # 轮次
plt.ylabel("Loss")  # 损失
plt.title("Train and Validation Losses Over Epochs", fontsize=14)  # 训练和验证损失随轮次变化
plt.legend()
plt.grid()
plt.show()

png

让我们展示在CIFAR-100上训练的最终结果。

loss, accuracy, top_5_accuracy = model.evaluate(dataset_test)  # 评估模型
print(f"Test loss: {round(loss, 2)}")  # 测试损失
print(f"Test accuracy: {round(accuracy * 100, 2)}%")  # 测试精度
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")  # 测试前五名精度
 79/79 ━━━━━━━━━━━━━━━━━━━━ 26s 325ms/step - accuracy: 0.4474 - loss: 2.7119 - top-5-accuracy: 0.7556
Test loss: 2.7
Test accuracy: 44.8%
Test top 5 accuracy: 75.23%

我们刚训练的Swin Transformer模型仅有152K参数,在40个轮次内达到了约75%的测试前五名精度,并且如上图所示没有过拟合的迹象。这意味着我们可以让这个网络训练更久(或许增加一点正则化)并获得更好的性能。通过其他技术,如余弦衰减学习率调度、其他数据增强技术,性能可以进一步提升。在实验中,我尝试将模型训练150个轮次,增加了一点dropout和更大的嵌入维度,这使得CIFAR-100的测试精度提升到约72%,如截图所示。

训练更久的结果

作者在ImageNet上报告了87.3%的top-1准确率。作者还进行了一系列实验,研究输入尺寸、优化器等是如何影响该模型的最终性能的。作者还展示了将该模型用于目标检测、语义分割和实例分割,并报告了这些任务的竞争性结果。强烈建议你查看原始论文

这个例子受到官方PyTorchTensorFlow实现的启发。