代码示例 / 计算机视觉 / MixUp 数据增强用于图像分类

MixUp 数据增强用于图像分类

作者: Sayak Paul
创建日期: 2021/03/06
最后修改日期: 2023/07/24
描述: 使用 mixup 技术进行图像分类的数据增强。

在 Colab 中查看 GitHub 源代码


介绍

mixup 是一种 与域无关 的数据增强技术,提出于 mixup: Beyond Empirical Risk Minimization 由 Zhang 等人实现。它的公式如下:

(请注意,lambda 值是在 [0, 1] 范围内的值,并从 Beta 分布中取样。)

该技术的命名相当系统。我们实际上是在混合特征及其对应的标签。从实现的角度来看,它很简单。神经网络容易 记忆错误标签。mixup通过将不同的特征结合在一起来放宽这一点(标签也会是一样的),以便 网络不会对特征与其标签之间的关系过于自信。

当我们不确定为给定数据集选择一组增强转换时,mixup尤其有用,例如医学成像数据集。mixup 可以扩展到各种数据类型,例如计算机视觉、自然语言 处理、语音等。


设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# 与 tf.data 预处理相关的 TF 导入
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma

准备数据集

在本例中,我们将使用 FashionMNIST 数据集。但这个相同的配方也可以 用于其他分类数据集。

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = keras.ops.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = keras.ops.one_hot(y_test, 10)

定义超参数

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

将数据转换为 TensorFlow Dataset 对象

# 留出一些样本来创建我们的验证集
val_samples = 2000
x_val, y_val = x_train[:val_samples], y_train[:val_samples]
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
)
# 因为我们将混合图像及其对应的标签,所以我们将从同一训练数据中组合两个随机打乱的数据集。
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

定义 mixup 技术函数

为了执行 mixup 例程,我们使用来自同一数据集的训练数据创建新的虚拟数据集,并应用一个在 [0, 1] 范围内的 lambda 值,这个值从 Beta 分布 中取样——例如,new_x = lambda * x1 + (1 - lambda) * x2(其中 x1x2 是图像),同样的公式也用于标签。

def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
    # 解包两个数据集
    images_one, labels_one = ds_one
    images_two, labels_two = ds_two
    batch_size = keras.ops.shape(images_one)[0]

    # 采样lambda并重塑以进行mixup
    l = sample_beta_distribution(batch_size, alpha, alpha)
    x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
    y_l = keras.ops.reshape(l, (batch_size, 1))

    # 通过组合一对图像/标签(来自每个数据集的一个)来对图像和标签进行mixup
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)

注意 在这里,我们将两幅图像组合成一幅图像。从理论上讲,我们可以组合任意数量的图像,但这会导致计算成本的增加。在某些情况下,这也可能不会帮助提高性能。


可视化新的增强数据集

# 首先使用我们的 `mix_up` 工具创建新数据集
train_ds_mu = train_ds.map(
    lambda ds_one, ds_two: mix_up(ds_one, ds_two, alpha=0.2),
    num_parallel_calls=AUTO,
)

# 让我们预览数据集中的9个样本
sample_images, sample_labels = next(iter(train_ds_mu))
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(zip(sample_images[:9], sample_labels[:9])):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().squeeze())
    print(label.numpy().tolist())
    plt.axis("off")
[0.0, 0.9964277148246765, 0.0, 0.0, 0.003572270041331649, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.9794676899909973, 0.02053229510784149, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.9536369442939758, 0.0, 0.0, 0.0, 0.04636305570602417, 0.0]
[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7631776928901672, 0.0, 0.0, 0.23682232201099396]
[0.0, 0.0, 0.045958757400512695, 0.0, 0.0, 0.0, 0.9540412425994873, 0.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0, 2.8015051611873787e-08, 0.0, 0.0, 1.0, 0.0, 0.0]
[0.0, 0.0, 0.0, 0.0003173351287841797, 0.0, 0.9996826648712158, 0.0, 0.0, 0.0, 0.0]

png


模型构建

def get_training_model():
    model = keras.Sequential(
        [
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(16, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(32, (5, 5), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Dropout(0.2),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation="relu"),
            layers.Dense(10, activation="softmax"),
        ]
    )
    return model

为了可重复性,我们对浅层网络的初始随机权重进行序列化。

initial_model = get_training_model()
initial_model.save_weights("initial_weights.weights.h5")

1. 用混合数据集训练模型

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_mu, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
Epoch 1/10
  62/907 ━━━━━━━━━━━━━━━━━━━━  2s 3ms/step - accuracy: 0.2518 - loss: 2.2072

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699655923.381468   16749 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

 907/907 ━━━━━━━━━━━━━━━━━━━━ 13s 9ms/step - accuracy: 0.5335 - loss: 1.4414 - val_accuracy: 0.7635 - val_loss: 0.6678
Epoch 2/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 12s 4ms/step - accuracy: 0.7168 - loss: 0.9688 - val_accuracy: 0.7925 - val_loss: 0.5849
Epoch 3/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - accuracy: 0.7525 - loss: 0.8940 - val_accuracy: 0.8290 - val_loss: 0.5138
Epoch 4/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.7742 - loss: 0.8431 - val_accuracy: 0.8360 - val_loss: 0.4726
Epoch 5/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.7876 - loss: 0.8095 - val_accuracy: 0.8550 - val_loss: 0.4450
Epoch 6/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8029 - loss: 0.7794 - val_accuracy: 0.8560 - val_loss: 0.4178
Epoch 7/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.8039 - loss: 0.7632 - val_accuracy: 0.8600 - val_loss: 0.4056
Epoch 8/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7465 - val_accuracy: 0.8510 - val_loss: 0.4114
Epoch 9/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8115 - loss: 0.7364 - val_accuracy: 0.8645 - val_loss: 0.3983
Epoch 10/10
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - accuracy: 0.8182 - loss: 0.7237 - val_accuracy: 0.8630 - val_loss: 0.3735
 157/157 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.8610 - loss: 0.4030
Test accuracy: 85.82%

2. 不使用混合数据集训练模型

model = get_training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
# 注意这里我们没有使用混合数据集
model.fit(train_ds_one, validation_data=val_ds, epochs=EPOCHS)
_, test_acc = model.evaluate(test_ds)
print("Test accuracy: {:.2f}%".format(test_acc * 100))
第 1 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/步 - 准确率: 0.5690 - 损失: 1.1928 - 验证准确率: 0.7585 - 验证损失: 0.6519
第 2 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 5s 2ms/步 - 准确率: 0.7525 - 损失: 0.6484 - 验证准确率: 0.7860 - 验证损失: 0.5799
第 3 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/步 - 准确率: 0.7895 - 损失: 0.5661 - 验证准确率: 0.8205 - 验证损失: 0.5122
第 4 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/步 - 准确率: 0.8148 - 损失: 0.5126 - 验证准确率: 0.8415 - 验证损失: 0.4375
第 5 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/步 - 准确率: 0.8306 - 损失: 0.4636 - 验证准确率: 0.8610 - 验证损失: 0.3913
第 6 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/步 - 准确率: 0.8433 - 损失: 0.4312 - 验证准确率: 0.8680 - 验证损失: 0.3734
第 7 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/步 - 准确率: 0.8544 - 损失: 0.4072 - 验证准确率: 0.8750 - 验证损失: 0.3606
第 8 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/步 - 准确率: 0.8577 - 损失: 0.3913 - 验证准确率: 0.8735 - 验证损失: 0.3520
第 9 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/步 - 准确率: 0.8645 - 损失: 0.3803 - 验证准确率: 0.8725 - 验证损失: 0.3536
第 10 轮/共 10 轮
 907/907 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/步 - 准确率: 0.8686 - 损失: 0.3597 - 验证准确率: 0.8745 - 验证损失: 0.3395
 157/157 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 准确率: 0.8705 - 损失: 0.3672
测试准确率: 86.92%

鼓励读者在不同领域的不同数据集上尝试 mixup,并实验 lambda 参数。强烈建议您查看 原始论文 - 作者展示了几个关于 mixup 的消融研究,表明它如何改善泛化,并展示了将多个图像结合生成单个图像的结果。


注意事项

  • 使用 mixup,您可以创建合成示例——特别是在您缺乏大型数据集时——而无需承担高计算成本。
  • 标签平滑 和 mixup 通常不太兼容,因为标签平滑已经通过某种因子修改了硬标签。
  • 使用 监督对比学习(SCL)时,mixup 表现不佳,因为 SCL 在其预训练阶段期望真实标签。
  • mixup 的其他几个好处包括(如 论文 中所述)应对对抗性示例的鲁棒性和稳定的 GAN(生成对抗网络)训练。
  • 有许多扩展 mixup 的数据增强技术,例如 CutMixAugMix