代码示例 / 计算机视觉 / 图像分类使用大迁移学习(BiT)

图像分类使用大迁移学习(BiT)

作者: Sayan Nath
创建日期: 2021/09/24
最后修改日期: 2024/01/03
描述: 大迁移学习(BiT)在图像分类中的最先进迁移学习方法。

在Colab中查看 GitHub源码


介绍

大迁移学习(也称为BiT)是一种最先进的图像分类迁移学习方法。预训练表征的迁移提高了样本效率,并在训练深度神经网络进行视觉任务时简化了超参数调整。BiT重新审视在大型监督数据集上进行预训练并在目标任务上微调模型的范式。适当选择归一化层以及根据预训练数据量增加而扩展架构容量的重要性。

大迁移学习(BiT)是在公共数据集上训练的,代码在 TF2、Jax和Pytorch中提供。这将帮助任何人达到他们感兴趣任务的最先进性能,即使每类只有少量标记图像。

您可以找到在 ImageNet和ImageNet-21k上预训练的BiT模型,在 TFHub作为TensorFlow2 SavedModels,您可以轻松用作Keras层。对于计算和内存预算较大但对准确性要求较高的用户,有多种大小可供选择,从标准ResNet50到ResNet152x4(152层深,宽度为典型ResNet50的4倍)。

图:x轴显示每个类别使用的图像数量,从1到完整数据集。在左侧的图中,上面的蓝色曲线是我们的BiT-L模型,而下面的曲线是预训练于ImageNet(ILSVRC-2012)的ResNet-50。


设置

import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
from keras import ops
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

tfds.disable_progress_bar()

SEEDS = 42

keras.utils.set_random_seed(SEEDS)

收集花卉数据集

train_ds, validation_ds = tfds.load(
    "tf_flowers",
    split=["train[:85%]", "train[85%:]"],
    as_supervised=True,
)
下载并准备数据集 218.21 MiB(下载:218.21 MiB,生成:221.83 MiB,总计:440.05 MiB)到 ~/tensorflow_datasets/tf_flowers/3.0.1...
数据集tf_flowers已下载并准备好到~/tensorflow_datasets/tf_flowers/3.0.1。后续调用将重用此数据。

可视化数据集

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png


定义超参数

RESIZE_TO = 384
CROP_TO = 224
BATCH_SIZE = 64
STEPS_PER_EPOCH = 10
AUTO = tf.data.AUTOTUNE  # 优化管道性能
NUM_CLASSES = 5  # 类别数量
SCHEDULE_LENGTH = (
    500  # 我们将对较低分辨率的图像进行训练,仍然会获得良好的结果
)
SCHEDULE_BOUNDARIES = [
    200,
    300,
    400,
]  # 数据集大小越大,计划长度增加

超参数如SCHEDULE_LENGTHSCHEDULE_BOUNDARIES是根据经验结果确定的。该方法在原始论文和他们的Google AI博客文章中进行了说明。

SCHEDULE_LENGTH还决定是否使用MixUp增强。您也可以在Keras编码示例中找到简单的MixUp实现。


定义预处理辅助函数

SCHEDULE_LENGTH = SCHEDULE_LENGTH * 512 / BATCH_SIZE

random_flip = keras.layers.RandomFlip("horizontal")
random_crop = keras.layers.RandomCrop(CROP_TO, CROP_TO)

def preprocess_train(image, label):
    image = random_flip(image)
    image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
    image = random_crop(image)
    image = image / 255.0
    return (image, label)


def preprocess_test(image, label):
    image = ops.image.resize(image, (RESIZE_TO, RESIZE_TO))
    image = ops.cast(image, dtype="float32")
    image = image / 255.0
    return (image, label)


DATASET_NUM_TRAIN_EXAMPLES = train_ds.cardinality().numpy()

repeat_count = int(
    SCHEDULE_LENGTH * BATCH_SIZE / DATASET_NUM_TRAIN_EXAMPLES * STEPS_PER_EPOCH
)
repeat_count += 50 + 1  # 确保至少进行50个训练周期

定义数据管道

# 训练管道
pipeline_train = (
    train_ds.shuffle(10000)
    .repeat(repeat_count)  # 重复数据集大小 / 步骤数量
    .map(preprocess_train, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# 验证管道
pipeline_validation = (
    validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

可视化训练样本

image_batch, label_batch = next(iter(pipeline_train))

plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(image_batch[n])
    plt.title(label_batch[n].numpy())
    plt.axis("off")

png


将预训练的TF-Hub模型加载到KerasLayer

bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
bit_module = hub.load(bit_model_url)

创建BigTransfer (BiT)模型

要创建新的模型,我们:

  1. 剪掉BiT模型的原始头。这将我们留下“预激活”输出。 如果我们使用“特征提取器”模型(即所有位于名为feature_vectors的子目录中的模型),我们就不必这样做,因为对于这些模型,头部已经被剪掉。

  2. 添加一个新的头,其输出数量等于我们新任务的类别数量。请注意,初始化头部为零是很重要的。

class MyBiTModel(keras.Model):
    def __init__(self, num_classes, module, **kwargs):
        super().__init__(**kwargs)

        self.num_classes = num_classes
        self.head = keras.layers.Dense(num_classes, kernel_initializer="zeros")
        self.bit_model = module

    def call(self, images):
        bit_embedding = self.bit_model(images)
        return self.head(bit_embedding)


model = MyBiTModel(num_classes=NUM_CLASSES, module=bit_module)

定义优化器和损失

learning_rate = 0.003 * BATCH_SIZE / 512

# 在 SCHEDULE_BOUNDARIES 处分数学习率
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=SCHEDULE_BOUNDARIES,
    values=[
        learning_rate,
        learning_rate * 0.1,
        learning_rate * 0.01,
        learning_rate * 0.001,
    ],
)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

编译模型

model.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

设置回调

train_callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=2, restore_best_weights=True
    )
]

训练模型

history = model.fit(
    pipeline_train,
    batch_size=BATCH_SIZE,
    epochs=int(SCHEDULE_LENGTH / STEPS_PER_EPOCH),
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=pipeline_validation,
    callbacks=train_callbacks,
)
Epoch 1/400
10/10 [==============================] - 18s 852ms/step - loss: 0.7465 - accuracy: 0.7891 - val_loss: 0.1865 - val_accuracy: 0.9582
Epoch 2/400
10/10 [==============================] - 5s 529ms/step - loss: 0.1389 - accuracy: 0.9578 - val_loss: 0.1075 - val_accuracy: 0.9727
Epoch 3/400
10/10 [==============================] - 5s 520ms/step - loss: 0.1720 - accuracy: 0.9391 - val_loss: 0.0858 - val_accuracy: 0.9727
Epoch 4/400
10/10 [==============================] - 5s 525ms/step - loss: 0.1211 - accuracy: 0.9516 - val_loss: 0.0833 - val_accuracy: 0.9691

绘制训练和验证指标

def plot_hist(hist):
    plt.plot(hist.history["accuracy"])
    plt.plot(hist.history["val_accuracy"])
    plt.plot(hist.history["loss"])
    plt.plot(hist.history["val_loss"])
    plt.title("训练进度")
    plt.ylabel("准确率/损失")
    plt.xlabel("历次")
    plt.legend(["train_acc", "val_acc", "train_loss", "val_loss"], loc="upper left")
    plt.show()


plot_hist(history)

png


评估模型

accuracy = model.evaluate(pipeline_validation)[1] * 100
print("准确率: {:.2f}%".format(accuracy))
9/9 [==============================] - 3s 364ms/step - loss: 0.1075 - accuracy: 0.9727
准确率: 97.27%

结论

BiT在出乎意料的范围内的数据显示良好 – 从每个类别1个示例到总共100万个示例。BiT在ILSVRC-2012上达到87.5%的顶级准确率,CIFAR-10上的准确率为99.4%,在19个任务的视觉任务适应基准(VTAB)上为76.3%。在小数据集上,BiT在类上有10个示例时,在ILSVRC-2012上达到76.8%,在CIFAR-10上达到97.0%。

您可以通过遵循BigTransfer方法进一步进行实验。 原始论文

HuggingFace上提供示例 | 训练模型 | 演示 | | :–: | :–: | | Generic badge | Generic badge |