代码示例 / 计算机视觉 / CutMix 数据增强用于图像分类

CutMix 数据增强用于图像分类

作者: Sayan Nath
创建日期: 2021/06/08
最后修改日期: 2023/11/14
描述: 使用 CutMix 进行 CIFAR-10 的图像分类数据增强。

在 Colab 查看 GitHub 源代码


引言

CutMix 是一种数据增强技术,旨在解决区域掉落策略中存在的信息损失和效率低下的问题。 它不是简单地去除像素并用黑、灰色像素或高斯噪声填充,而是用另一张图像的一个块替代被去除的区域,同时根据合成图像的像素数量按比例混合真实标签。 CutMix 在 CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (Yun et al., 2019)中提出。

通过以下公式实现:

其中 M 是二进制掩码,指示来自两张随机抽取图像的切除和填充区域,λ (在 [0, 1] 内)来自 Beta(α, α) 分布

边界框的坐标为:

指示图像的切除和填充区域。 边界框采样由以下公式表示:

其中 rx, ry 是从一个具有上界的均匀分布中随机抽取的。


设置

import numpy as np
import keras
import matplotlib.pyplot as plt

from keras import layers

# 与 tf.data 预处理相关的 TF 导入
from tensorflow import clip_by_value
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import random as tf_random

keras.utils.set_random_seed(42)

加载 CIFAR-10 数据集

在此示例中,我们将使用 CIFAR-10 图像分类数据集

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)

print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

class_names = [
    "飞机",
    "汽车",
    "鸟",
    "猫",
    "鹿",
    "狗",
    "青蛙",
    "马",
    "船",
    "卡车",
]
(50000, 32, 32, 3)
(50000, 10)
(10000, 32, 32, 3)
(10000, 10)

定义超参数

AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 32
IMG_SIZE = 32

定义图像预处理函数

def preprocess_image(image, label):
    image = tf_image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf_image.convert_image_dtype(image, "float32") / 255.0
    label = keras.ops.cast(label, dtype="float32")
    return image, label

将数据转换为 TensorFlow Dataset 对象

train_ds_one = (
    tf_data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
)
train_ds_two = (
    tf_data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .map(preprocess_image, num_parallel_calls=AUTO)
)

train_ds_simple = tf_data.Dataset.from_tensor_slices((x_train, y_train))

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))

train_ds_simple = (
    train_ds_simple.map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# 从相同的训练数据组合两个打乱的数据集。
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

test_ds = (
    test_ds.map(preprocess_image, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

定义 CutMix 数据增强函数

CutMix 函数接受两个 imagelabel 对来执行增强。 它从 Beta 分布 中抽样 λ(l) 并返回一个来自 get_box 函数的边界框。然后,我们裁剪第二张图像(image2)并在最终填充的图像中在相同位置填充该图像。

def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
    gamma_1_sample = tf_random.gamma(shape=[size], alpha=concentration_1)
    gamma_2_sample = tf_random.gamma(shape=[size], alpha=concentration_0)
    return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def get_box(lambda_value):
    cut_rat = keras.ops.sqrt(1.0 - lambda_value)

    cut_w = IMG_SIZE * cut_rat  # 宽度
    cut_w = keras.ops.cast(cut_w, "int32")

    cut_h = IMG_SIZE * cut_rat  # 高度
    cut_h = keras.ops.cast(cut_h, "int32")

    cut_x = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE)  # x坐标
    cut_x = keras.ops.cast(cut_x, "int32")
    cut_y = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE)  # y坐标
    cut_y = keras.ops.cast(cut_y, "int32")

    boundaryx1 = clip_by_value(cut_x[0] - cut_w // 2, 0, IMG_SIZE)
    boundaryy1 = clip_by_value(cut_y[0] - cut_h // 2, 0, IMG_SIZE)
    bbx2 = clip_by_value(cut_x[0] + cut_w // 2, 0, IMG_SIZE)
    bby2 = clip_by_value(cut_y[0] + cut_h // 2, 0, IMG_SIZE)

    target_h = bby2 - boundaryy1
    if target_h == 0:
        target_h += 1

    target_w = bbx2 - boundaryx1
    if target_w == 0:
        target_w += 1

    return boundaryx1, boundaryy1, target_h, target_w


def cutmix(train_ds_one, train_ds_two):
    (image1, label1), (image2, label2) = train_ds_one, train_ds_two

    alpha = [0.25]
    beta = [0.25]

    # 从Beta分布中获取一个样本
    lambda_value = sample_beta_distribution(1, alpha, beta)

    # 定义Lambda
    lambda_value = lambda_value[0][0]

    # 获取边界框的偏移、高度和宽度
    boundaryx1, boundaryy1, target_h, target_w = get_box(lambda_value)

    # 从第二幅图像获取一个补丁(`image2`)
    crop2 = tf_image.crop_to_bounding_box(
        image2, boundaryy1, boundaryx1, target_h, target_w
    )
    # 用相同的偏移量对`image2`补丁(`crop2`)进行填充
    image2 = tf_image.pad_to_bounding_box(
        crop2, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
    )
    # 从第一幅图像获取一个补丁(`image1`)
    crop1 = tf_image.crop_to_bounding_box(
        image1, boundaryy1, boundaryx1, target_h, target_w
    )
    # 用相同的偏移量对`image1`补丁(`crop1`)进行填充
    img1 = tf_image.pad_to_bounding_box(
        crop1, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
    )

    # 通过从`image1`中减去补丁来修改第一幅图像
    # (在应用`image2`补丁之前)
    image1 = image1 - img1
    # 将修改后的`image1`和`image2`相加以获得CutMix图像
    image = image1 + image2

    # 根据像素比率调整Lambda
    lambda_value = 1 - (target_w * target_h) / (IMG_SIZE * IMG_SIZE)
    lambda_value = keras.ops.cast(lambda_value, "float32")

    # 合并两幅图像的标签
    label = lambda_value * label1 + (1 - lambda_value) * label2
    return image, label

注意: 我们正在合并两个图像以创建一个单一的图像。


在应用了 CutMix 增强后可视化新数据集

# 使用我们的 `cutmix` 工具创建新数据集
train_ds_cmu = (
    train_ds.shuffle(1024)
    .map(cutmix, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# 让我们预览数据集中的 9 个样本
image_batch, label_batch = next(iter(train_ds_cmu))
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.title(class_names[np.argmax(label_batch[i])])
    plt.imshow(image_batch[i])
    plt.axis("off")

png


定义 ResNet-20 模型

def resnet_layer(
    inputs,
    num_filters=16,
    kernel_size=3,
    strides=1,
    activation="relu",
    batch_normalization=True,
    conv_first=True,
):
    conv = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        kernel_initializer="he_normal",
        kernel_regularizer=keras.regularizers.L2(1e-4),
    )
    x = inputs
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = layers.BatchNormalization()(x)
        if activation is not None:
            x = layers.Activation(activation)(x)
    else:
        if batch_normalization:
            x = layers.BatchNormalization()(x)
        if activation is not None:
            x = layers.Activation(activation)(x)
        x = conv(x)
    return x


def resnet_v20(input_shape, depth, num_classes=10):
    if (depth - 2) % 6 != 0:
        raise ValueError("深度应该是 6n+2 (例如 20, 32, 44 在 [a] 中)")
    # 开始模型定义。
    num_filters = 16
    num_res_blocks = int((depth - 2) / 6)

    inputs = layers.Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # 实例化残差单元堆栈
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack > 0 and res_block == 0:  # 第一层但不是第一堆
                strides = 2  # 下采样
            y = resnet_layer(inputs=x, num_filters=num_filters, strides=strides)
            y = resnet_layer(inputs=y, num_filters=num_filters, activation=None)
            if stack > 0 and res_block == 0:  # 第一层但不是第一堆
                # 线性投影残差快捷连接以匹配
                # 更改的维度
                x = resnet_layer(
                    inputs=x,
                    num_filters=num_filters,
                    kernel_size=1,
                    strides=strides,
                    activation=None,
                    batch_normalization=False,
                )
            x = layers.add([x, y])
            x = layers.Activation("relu")(x)
        num_filters *= 2

    # 在顶部添加分类器。
    # v1 在最后的快捷连接-ReLU 之后不使用 BN
    x = layers.AveragePooling2D(pool_size=8)(x)
    y = layers.Flatten()(x)
    outputs = layers.Dense(
        num_classes, activation="softmax", kernel_initializer="he_normal"
    )(y)

    # 实例化模型。
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


def training_model():
    return resnet_v20((32, 32, 3), 20)


initial_model = training_model()
initial_model.save_weights("initial_weights.weights.h5")

使用 CutMix 增强的数据集训练模型

model = training_model()
model.load_weights("initial_weights.weights.h5")

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_cmu, validation_data=test_ds, epochs=15)

test_loss, test_accuracy = model.evaluate(test_ds)
print("测试准确率: {:.2f}%".format(test_accuracy * 100))
Epoch 1/15
   10/1563 ━━━━━━━━━━━━━━━━━━━━  19s 13ms/step - 准确率: 0.0795 - 损失: 5.3035

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699988196.560261  362411 device_compiler.h:187] 使用 XLA 编译集群!这条信息在进程的生命周期中最多记录一次。

 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 64s 27ms/step - 准确率: 0.3148 - 损失: 2.1918 - 验证准确率: 0.4067 - 验证损失: 1.8339
Epoch 2/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - 准确率: 0.4295 - 损失: 1.9021 - 验证准确率: 0.5516 - 验证损失: 1.4744
Epoch 3/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - 准确率: 0.4883 - 损失: 1.8076 - 验证准确率: 0.5305 - 验证损失: 1.5067
Epoch 4/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - 准确率: 0.5243 - 损失: 1.7342 - 验证准确率: 0.6303 - 验证损失: 1.2822
Epoch 5/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - 准确率: 0.5574 - 损失: 1.6614 - 验证准确率: 0.5370 - 验证损失: 1.5912
Epoch 6/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - 准确率: 0.5832 - 损失: 1.6167 - 验证准确率: 0.6254 - 验证损失: 1.3116
Epoch 7/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - 准确率: 0.6045 - 损失: 1.5738 - 验证准确率: 0.6101 - 验证损失: 1.3408
Epoch 8/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - 准确率: 0.6170 - 损失: 1.5493 - 验证准确率: 0.6209 - 验证损失: 1.2923
Epoch 9/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 29s 18ms/step - 准确率: 0.6292 - 损失: 1.5299 - 验证准确率: 0.6290 - 验证损失: 1.2813
Epoch 10/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - 准确率: 0.6394 - 损失: 1.5110 - 验证准确率: 0.7234 - 验证损失: 1.0608
Epoch 11/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - 准确率: 0.6467 - 损失: 1.4915 - 验证准确率: 0.7498 - 验证损失: 0.9854
Epoch 12/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - 准确率: 0.6559 - 损失: 1.4785 - 验证准确率: 0.6481 - 验证损失: 1.2410
Epoch 13/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - 准确率: 0.6596 - 损失: 1.4656 - 验证准确率: 0.7551 - 验证损失: 0.9784
Epoch 14/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - 准确率: 0.6577 - 损失: 1.4637 - 验证准确率: 0.6822 - 验证损失: 1.1703
Epoch 15/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - 准确率: 0.6702 - 损失: 1.4445 - 验证准确率: 0.7108 - 验证损失: 1.0805
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - 准确率: 0.7140 - 损失: 1.0766
测试准确率: 71.08%

使用原始未增强数据集训练模型

model = training_model()
model.load_weights("initial_weights.weights.h5")
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(train_ds_simple, validation_data=test_ds, epochs=15)

test_loss, test_accuracy = model.evaluate(test_ds)
print("测试准确率: {:.2f}%".format(test_accuracy * 100))
Epoch 1/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 41s 15ms/step - accuracy: 0.3943 - loss: 1.8736 - val_accuracy: 0.5359 - val_loss: 1.4376
Epoch 2/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6160 - loss: 1.2407 - val_accuracy: 0.5887 - val_loss: 1.4254
Epoch 3/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6927 - loss: 1.0448 - val_accuracy: 0.6102 - val_loss: 1.4850
Epoch 4/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.7411 - loss: 0.9222 - val_accuracy: 0.6262 - val_loss: 1.3898
Epoch 5/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.7711 - loss: 0.8439 - val_accuracy: 0.6283 - val_loss: 1.3425
Epoch 6/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - accuracy: 0.7983 - loss: 0.7886 - val_accuracy: 0.2460 - val_loss: 5.6869
Epoch 7/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.8168 - loss: 0.7490 - val_accuracy: 0.1954 - val_loss: 21.7670
Epoch 8/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.8113 - loss: 0.7779 - val_accuracy: 0.1027 - val_loss: 36.3144
Epoch 9/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6592 - loss: 1.4179 - val_accuracy: 0.1025 - val_loss: 40.0770
Epoch 10/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - accuracy: 0.5611 - loss: 1.9856 - val_accuracy: 0.1699 - val_loss: 40.6308
Epoch 11/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.6076 - loss: 1.7795 - val_accuracy: 0.1003 - val_loss: 63.4775
Epoch 12/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6175 - loss: 1.8077 - val_accuracy: 0.1099 - val_loss: 21.9148
Epoch 13/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6468 - loss: 1.6702 - val_accuracy: 0.1576 - val_loss: 72.7290
Epoch 14/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6437 - loss: 1.7858 - val_accuracy: 0.1000 - val_loss: 64.9249
Epoch 15/15
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.6587 - loss: 1.7587 - val_accuracy: 0.1000 - val_loss: 138.8463
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.0988 - loss: 139.3117
测试准确率: 10.00%

备注

在这个例子中,我们训练了我们的模型 15 个周期。 在我们的实验中,使用 CutMix 的模型在 CIFAR-10 数据集上达到了更好的准确率 (在我们的实验中为 77.34%)相比于不使用增强的模型(66.90%)。 您可能会注意到,使用 CutMix 增强训练模型所需的时间更少。

您可以通过遵循 原始论文 进一步实验 CutMix 技术。