代码示例 / 计算机视觉 / RandAugment用于图像分类以提高鲁棒性

RandAugment用于图像分类以提高鲁棒性

作者: Sayak PaulSachin Prasad
创建日期: 2021/03/13
最后修改: 2023/12/12
描述: 使用RandAugment训练图像分类模型以提高鲁棒性。

在Colab中查看 GitHub源代码

数据增强是一种非常有用的技术,可以帮助提高卷积神经网络(CNN)的平移不变性。RandAugment是一种用于视觉数据的随机数据增强程序,提出于RandAugment:具有减少搜索空间的实用自动数据增强。它由强大的增强变换组成,如颜色抖动、高斯模糊、饱和度等,以及更传统的增强变换,如随机裁剪。

这些参数是针对给定数据集和网络架构进行调整的。RandAugment的作者也在原始论文中提供了RandAugment的伪代码(图2)。

最近,它已经成为像Noisy Student Training无监督数据增强用于一致性训练等工作的关键组成部分。它也对EfficientNets的成功至关重要。

pip install keras-cv

导入和设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
keras.utils.set_random_seed(42)

加载CIFAR10数据集

在此示例中,我们将使用CIFAR10数据集

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(f"训练样本总数: {len(x_train)}")
print(f"测试样本总数: {len(x_test)}")
训练样本总数: 50000
测试样本总数: 10000

定义超参数

AUTO = tf.data.AUTOTUNE
BATCH_SIZE = 128
EPOCHS = 1
IMAGE_SIZE = 72

初始化RandAugment对象

现在,我们将从imgaug.augmenters模块中初始化一个RandAugment对象,使用RandAugment作者建议的参数。

rand_augment = keras_cv.layers.RandAugment(
    value_range=(0, 255), augmentations_per_image=3, magnitude=0.8
)

创建TensorFlow Dataset对象

train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .map(
        lambda x, y: (rand_augment(tf.cast(x, tf.uint8)), y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

为了比较,我们还定义了一个简单的增强管道,包括随机翻转、随机旋转和随机缩放。

simple_aug = keras.Sequential(
    [
        layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ]
)

# 现在,将增强管道映射到我们的训练数据集
train_ds_simple = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(lambda x, y: (simple_aug(x), y), num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

可视化使用RandAugment增强的数据集

sample_images, _ = next(iter(train_ds_rand))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")

png

建议您运行上述代码块几次以查看不同的变体。


可视化使用simple_aug增强的数据集

sample_images, _ = next(iter(train_ds_simple))
plt.figure(figsize=(10, 10))
for i, image in enumerate(sample_images[:9]):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("int"))
    plt.axis("off")  # 关闭坐标轴

png


定义一个模型构建实用函数

现在,我们定义一个基于ResNet50V2架构的CNN模型。同时,注意到网络内部已经包含了一个重缩放层。这消除了对我们的数据集进行任何单独预处理的需要,这在部署时特别有用。

def get_training_model():
    resnet50_v2 = keras.applications.ResNet50V2(
        weights=None,
        include_top=True,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
        classes=10,
    )
    model = keras.Sequential(
        [
            layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
            layers.Rescaling(scale=1.0 / 127.5, offset=-1),  # 重缩放
            resnet50_v2,
        ]
    )
    return model

get_training_model().summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ rescaling (Rescaling)           │ (None, 72, 72, 3)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ resnet50v2 (Functional)         │ (None, 10)                │ 23,585,290 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 23,585,290 (89.97 MB)
 Trainable params: 23,539,850 (89.80 MB)
 Non-trainable params: 45,440 (177.50 KB)

我们将使用我们数据集的两个不同版本来训练这个网络:

  • 一个是通过RandAugment增强的。
  • 另一个是通过simple_aug增强的。

由于已知RandAugment能够增强模型对常见扰动和损坏的鲁棒性,我们还将在CIFAR-10-C数据集上评估我们的模型,该数据集是Hendrycks等人在神经网络鲁棒性基准测试中提出的。CIFAR-10-C数据集包含19种不同的图像损坏和扰动(例如斑点噪声、雾、Gaussian模糊等),并且严重程度各不相同。对于本例,我们将使用以下配置: cifar10_corrupted/saturate_5。 该配置中的图像如下所示:

为了可重复性,我们序列化我们浅层网络的初始随机权重。

initial_model = get_training_model()
initial_model.save_weights("initial.weights.h5")  # 保存初始权重

用RandAugment训练模型

rand_aug_model = get_training_model()
rand_aug_model.load_weights("initial.weights.h5")
rand_aug_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
rand_aug_model.fit(train_ds_rand, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = rand_aug_model.evaluate(test_ds)
print("测试准确率: {:.2f}%".format(test_acc * 100))
 391/391 ━━━━━━━━━━━━━━━━━━━━ 1146s 3s/step - 准确率: 0.1677 - 损失: 2.3232 - 验证准确率: 0.2818 - 验证损失: 1.9966
 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 489ms/step - 准确率: 0.2803 - 损失: 2.0073
测试准确率: 28.18%

使用 simple_aug 训练模型

simple_aug_model = get_training_model()
simple_aug_model.load_weights("initial.weights.h5")
simple_aug_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
simple_aug_model.fit(train_ds_simple, validation_data=test_ds, epochs=EPOCHS)
_, test_acc = simple_aug_model.evaluate(test_ds)
print("测试准确率: {:.2f}%".format(test_acc * 100))
 391/391 ━━━━━━━━━━━━━━━━━━━━ 1132s 3s/step - 准确率: 0.3673 - 损失: 1.7929 - 验证准确率: 0.4789 - 验证损失: 1.4296
 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 494ms/step - 准确率: 0.4762 - 损失: 1.4368
测试准确率: 47.89%

加载 CIFAR-10-C 数据集并评估性能

# 加载并准备 CIFAR-10-C 数据集
# (如果尚未下载,下载大约需要 10 分钟)
cifar_10_c = tfds.load("cifar10_corrupted/saturate_5", split="test", as_supervised=True)
cifar_10_c = cifar_10_c.batch(BATCH_SIZE).map(
    lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
    num_parallel_calls=AUTO,
)

# 评估 `rand_aug_model`
_, test_acc = rand_aug_model.evaluate(cifar_10_c, verbose=0)
print(
    "CIFAR-10-C (saturate_5) 上 RandAugment 的准确率: {:.2f}%".format(
        test_acc * 100
    )
)

# 评估 `simple_aug_model`
_, test_acc = simple_aug_model.evaluate(cifar_10_c, verbose=0)
print(
    "CIFAR-10-C (saturate_5) 上 simple_aug 的准确率: {:.2f}%".format(
        test_acc * 100
    )
)
 下载和准备数据集 2.72 GiB (下载: 2.72 GiB, 生成:未知大小, 总计: 2.72 GiB)到 /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0...
 数据集 cifar10_corrupted 下载并准备到 /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. 后续调用将重用此数据。
CIFAR-10-C (saturate_5) 上 RandAugment 的准确率: 30.36%
CIFAR-10-C (saturate_5) 上 simple_aug 的准确率: 37.18%

为了这个例子,我们只训练模型一个周期。在 CIFAR-10-C 数据集上,使用 RandAugment 的模型可以以更高的准确率表现得更好(例如,在一次实验中为 76.64%),相比之下,使用 simple_aug 训练的模型仅为 64.80%。RandAugment 也能帮助稳定训练。

在笔记本中,您可能会注意到,尽管 RandAugment 增加了训练时间,但我们能够在 CIFAR-10-C 数据集上获得更好的性能。您可以对其他损坏和扰动设置进行实验,这些设置也来自同一 CIFAR-10-C 数据集,看看 RandAugment 是否有帮助。

您还可以尝试在 RandAugment 对象中使用不同的 nm 值。在 原始论文 中,作者展示了特定任务中各个增强变换的影响以及一系列的消融研究。欢迎您查看这些内容。

RandAugment 在提高深度模型的鲁棒性方面显示出了巨大的进步,如 Noisy Student TrainingFixMatch 等工作所示。这使得 RandAugment 成为训练不同视觉模型的非常有用的方案。

您可以使用托管在 Hugging Face Hub 的训练模型,并在 Hugging Face Spaces 尝试演示。