开发者指南 / 迁移学习与微调

迁移学习与微调

作者: fchollet
创建日期: 2020/04/15
最后修改: 2023/06/25
描述: Keras 中迁移学习与微调的完整指南。

在 Colab 中查看 GitHub 源码


设置

import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

介绍

迁移学习 由以下几个步骤组成:将一个问题上学到的特征应用于一个新的相似问题。例如,从已学习识别浣熊的模型中提取的特征,可能对启动一个用于识别狸猫的模型有用。

迁移学习通常用于数据集数据量过少的任务,以至于无法从头开始训练一个完整的模型。

在深度学习的上下文中,迁移学习的最常见形式是以下工作流程:

  1. 从一个先前训练的模型中提取层。
  2. 将它们固定,以避免在未来的训练回合中破坏它们所包含的信息。
  3. 在固定层的顶部添加一些新的可训练层。这些层将学习将旧特征转换为新数据集上的预测。
  4. 在您的数据集上训练新的层。

最后一个可选步骤是 微调,即解冻上面获得的整个模型(或其部分),并使用非常低的学习率在新数据上重新训练它。这可能通过逐步适应预训练特征到新数据来实现显著改善。

首先,我们将详细介绍 Keras 的 trainable API,它是大多数迁移学习与微调工作流程的基础。

然后,我们将通过获取在 ImageNet 数据集上预训练的模型,并在 Kaggle "猫与狗" 分类数据集上重新训练它,演示典型的工作流程。

这部分内容改编自 Deep Learning with Python 和 2016 年的博客文章 "使用很少的数据构建强大的图像分类模型"


冻结层:理解 trainable 属性

层和模型有三个权重属性:

  • weights 是层的所有权重变量的列表。
  • trainable_weights 是那些旨在被更新的列表(通过梯度下降)以最小化训练过程中的损失。
  • non_trainable_weights 是那些不打算被训练的列表。通常它们在前向传播期间由模型更新。

示例:Dense 层有 2 个可训练权重(权重和偏置)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # 创建权重

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般来说,所有权重都是可训练权重。唯一具有非可训练权重的内置层是 BatchNormalization 层。它使用非可训练权重来跟踪其输入在训练过程中的均值和方差。 要学习如何在您自己的自定义层中使用非可训练权重,请参阅 自定义层编写指南

示例:BatchNormalization 层有 2 个可训练权重和 2 个非可训练权重

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # 创建权重

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

层和模型还具有布尔属性 trainable。其值可以更改。将 layer.trainable 设置为 False 会将所有层的权重从可训练转移到非可训练。这称为"冻结"层:被冻结的层的状态在训练期间不会更新(无论是使用 fit() 训练还是使用依赖 trainable_weights 应用梯度更新的任何自定义循环)。

示例:将 trainable 设置为 False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # 创建权重
layer.trainable = False  # 冻结层

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
权重: 2
可训练权重: 0
非可训练权重: 2

当一个可训练权重变为非可训练时,在训练过程中它的值不再被更新。

# 创建一个包含2层的模型
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# 冻结第一层
layer1.trainable = False

# 保留layer1权重的副本以便后续参考
initial_layer1_weights_values = layer1.get_weights()

# 训练模型
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# 检查layer1的权重在训练期间是否未改变
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615

不要将layer.trainable属性与layer.__call__()中的参数training混淆(该参数控制层是否应在推理模式或训练模式下运行其向前通过)。有关更多信息,请参见 Keras FAQ


trainable属性的递归设置

如果在一个模型或具有子层的任何层上设置trainable = False,那么所有子层也都将变为非可训练。

示例:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)

model.trainable = False  # 冻结外层模型

assert inner_model.trainable == False  # `model`中的所有层现在都被冻结
assert inner_model.layers[0].trainable == False  # `trainable`是递归传播的

典型的迁移学习工作流程

这使我们了解到如何在Keras中实现典型的迁移学习工作流程:

  1. 实例化一个基本模型并加载预训练权重。
  2. 通过设置trainable = False来冻结基本模型中的所有层。
  3. 在基本模型的一个(或多个)层的输出上创建一个新模型。
  4. 在新数据集上训练新模型。

请注意,另一种更轻量级的工作流程也可以是:

  1. 实例化一个基本模型并加载预训练权重。
  2. 将新数据集通过模型运行,并记录基本模型中一个(或多个)层的输出。这被称为特征提取
  3. 将该输出用作新较小模型的输入数据。

第二种工作流程的一个关键优势是,您只需在数据上运行一次基本模型,而不是每个训练周期运行一次。因此它更快且更省钱。

然而,第二种工作流程的问题在于,它不允许您在训练期间动态修改新模型的输入数据,例如在进行数据增强时。迁移学习通常用于当新数据集的数据量太少,无法从头训练一个完整规模的模型的任务,而在这种情况下,数据增强非常重要。因此,在接下来的部分中,我们将重点关注第一种工作流程。

以下是在Keras中第一种工作流程的示例:

首先,实例化一个带有预训练权重的基本模型。

base_model = keras.applications.Xception(
    weights='imagenet',  # 加载在ImageNet上预训练的权重。
    input_shape=(150, 150, 3),
    include_top=False)  # 不包含顶部的ImageNet分类器。

然后,冻结基本模型。

base_model.trainable = False

在其上创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# 我们确保基本模型在这里以推理模式运行,
# 通过传递`training=False`。这对于微调很重要,正如您将在接下来的段落中学习到的。
x = base_model(inputs, training=False)
# 将形状为`base_model.output_shape[1:]`的特征转换为向量
x = keras.layers.GlobalAveragePooling2D()(x)
# 一个带有单个单元的Dense分类器(二分类)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新数据上训练模型。

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦您的模型在新数据上收敛,您可以尝试解冻全部或部分 the base model and retrain the whole model end-to-end with a very low learning rate.

这一步是可选的最后步骤,可能会给你带来增量改进。 它也可能导致快速过拟合——请注意这一点。

在冻结层的模型训练完全收敛后执行此步骤是至关重要的。如果将随机初始化的可训练层与保持预训练特征的可训练层混合,随机初始化的层将在训练期间引起非常大的梯度更新,从而摧毁你的预训练特征。

此阶段使用非常低的学习率也是至关重要的,因为 你正在比第一次训练时训练一个更大的模型,且数据集通常非常小。 因此,如果应用大权重更新,你很容易快速过拟合。在这里,你只想以增量方式重新调整预训练权重。

以下是如何实施整个基础模型的微调:

# 解冻基础模型
base_model.trainable = True

# 在您对任何内部层的 `trainable` 属性进行任何更改后重新编译模型很重要,以便您的更改
# 能够被考虑到
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # 非常低的学习率
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# 端到端训练。小心在过拟合之前停止!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于 compile()trainable 的重要说明

对模型调用 compile() 是为了“冻结”该模型的行为。 这意味着在模型编译时的 trainable 属性值应该在该模型的生命周期内保持不变, 直到再次调用 compile。因此,如果你更改任何 trainable 值,请确保再次在模型上调用 compile(),以便你的更改能够生效。

关于 BatchNormalization 层的重要说明

许多图像模型包含 BatchNormalization 层。该层在可想象的每种情况下都是特例。以下是一些需要记住的事项。

  • BatchNormalization 包含 2 个不可训练的权重,这些权重在训练期间会被更新。这些变量跟踪输入的均值和方差。
  • 当你设置 bn_layer.trainable = False 时,BatchNormalization 层将以推理模式运行,并且不会更新其均值和方差统计信息。这一般不适用于其他层,因为 权重可训练性和推理/训练模式是两个正交概念。 但在 BatchNormalization 层的情况下,这两者是绑定在一起的。
  • 当你解冻一个包含 BatchNormalization 层的模型进行微调时,应该通过在调用基础模型时传递 training=FalseBatchNormalization 层保持在推理模式。 否则,应用于不可训练权重的更新将突然破坏模型所学到的内容。

你将在本指南结束时的端到端示例中看到这种模式的具体应用。


端到端示例:在猫与狗的数据集上微调图像分类模型

为了巩固这些概念,让我们引导您通过一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型,并将其用于 Kaggle 的“猫与狗”分类数据集。

获取数据

首先,让我们通过 TFDS 获取猫与狗的数据集。如果你有自己的数据集,你可能会想使用实用工具 keras.utils.image_dataset_from_directory 从一组存储在类特定文件夹中的图像生成类似的有标签数据集对象。

迁移学习在处理非常小的数据集时最有用。为了保持我们数据集的小规模,我们将使用原始训练数据的 40%(25,000 张图像)进行 训练,10% 用于验证,10% 用于测试。

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # 保留 10% 用于验证和 10% 用于测试
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # 包含标签
)

print(f"训练样本数量: {train_ds.cardinality()}")
print(f"验证样本数量: {validation_ds.cardinality()}")
print(f"测试样本数量: {test_ds.cardinality()}")
 下载数据集 786.68 MiB (下载: 786.68 MiB, 生成: 未知大小, 总计: 786.68 MiB) 到 /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0...

警告:absl:1738 张图片已损坏并被跳过

 数据集 cats_vs_dogs 已下载并准备好在 /home/mattdangerw/tensorflow_datasets/cats_vs_dogs/4.0.0. 后续调用将重用此数据。
训练样本数量: 9305
验证样本数量: 2326
测试样本数量: 2326

这是训练数据集中的前 9 张图像 – 如你所见,它们的大小各不相同。

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

我们也可以看到,标签 1 是“狗”,标签 0 是“猫”。

数据标准化

我们的原始图像大小各异。此外,每个像素由 0 到 255 之间的 3 个整数值(RGB 级别值)组成。这并不适合用于神经网络。我们需要做两件事:

  • 标准化为固定的图像大小。我们选择 150x150。
  • 将像素值归一化到 -1 和 1 之间。我们将使用一个 Normalization 层作为模型的一部分来做到这一点。

一般来说,开发以原始数据为输入的模型是个好习惯,而不是处理过的数据。理由是,如果你的模型期望处理过的数据,每当你导出模型以便在其他地方使用(在网络浏览器中,在移动应用程序中),你需要重新实现完全相同的预处理管道。这会迅速变得非常复杂。因此,我们应该在模型运行之前尽量减少预处理的数量。

在这里,我们将在数据管道中进行图像调整大小(因为深度神经网络只能处理连续的数据批次),并在创建模型时进行输入值缩放。

让我们将图像调整为 150x150:

resize_fn = keras.layers.Resizing(150, 150)

train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))

使用随机数据增强

当你没有大型图像数据集时,通过对训练图像进行随机但现实的变换来人工引入样本多样性是个好习惯,例如随机水平翻转或小范围随机旋转。这有助于使模型接触到训练数据的不同方面,同时减缓过拟合。

augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(x):
    for layer in augmentation_layers:
        x = layer(x)
    return x


train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))

让我们对数据进行批处理,并使用预取来优化加载速度。

from tensorflow import data as tf_data

batch_size = 64

train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()

让我们可视化第一批中的第一张图像在各种随机变换后的样子:

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(np.expand_dims(first_image, 0))
        plt.imshow(np.array(augmented_image[0]).astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")

png


构建模型

现在让我们构建一个遵循之前解释的蓝图的模型。

请注意:

  • 我们添加了一个 Rescaling 层,将输入值(最初在 [0, 255] 范围内)缩放到 [-1, 1] 范围内。
  • 我们在分类层之前添加了一个 Dropout 层,以进行正则化。
  • 我们确保在调用基础模型时传递 training=False,以便它以推理模式运行,这样即使我们解冻基础模型进行微调时,batchnorm 统计信息也不会被更新。
base_model = keras.applications.Xception(
    weights="imagenet",  # 加载在 ImageNet 上预训练的权重。
    input_shape=(150, 150, 3),
    include_top=False,
)  # 不包括顶部的 ImageNet 分类器。

# 冻结基础模型
base_model.trainable = False

# 在顶部创建新模型
inputs = keras.Input(shape=(150, 150, 3))

# 预训练的 Xception 权重要求输入在 (-1., +1.) 范围内
# rescaling 层输出: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)

# 基础模型包含 batchnorm 层。我们希望在解冻基础模型进行微调时保持它们处于推理模式,
# 因此我们确保基础模型在此时以推理模式运行。
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # 使用 dropout 进行正则化
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary(show_trainable=True)
从 https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5 下载数据
 83683744/83683744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
模型: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃ 层 (类型)                  输出形状                参数 #  可训练… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ input_layer_4 (输入层)  │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ rescaling (重标定)       │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ xception (功能性)       │ (None, 5, 5, 2048)       │ 20,861…N   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ global_average_pooling2d    │ (None, 2048)             │       0-   │
│ (全局平均池化2D)    │                          │         │       │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dropout (丢弃)           │ (None, 2048)             │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dense_7 (密集)             │ (None, 1)                │   2,049Y   │
└─────────────────────────────┴──────────────────────────┴─────────┴───────┘
 总参数: 20,863,529 (79.59 MB)
 可训练参数: 2,049 (8.00 KB)
 非可训练参数: 20,861,480 (79.58 MB)

训练顶层

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
print("正在拟合模型的顶层")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
正在拟合模型的顶层
Epoch 1/2
  78/146 ━━━━━━━━━━━━━━━━━━━━  15s 226ms/step - binary_accuracy: 0.7995 - loss: 0.4088

Corrupt JPEG data: 65 extraneous bytes before marker 0xd9

 136/146 ━━━━━━━━━━━━━━━━━━━━  2s 231ms/step - binary_accuracy: 0.8430 - loss: 0.3298

Corrupt JPEG data: 239 extraneous bytes before marker 0xd9

 143/146 ━━━━━━━━━━━━━━━━━━━━  0s 231ms/step - binary_accuracy: 0.8464 - loss: 0.3235

Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9

 144/146 ━━━━━━━━━━━━━━━━━━━━  0s 231ms/step - binary_accuracy: 0.8468 - loss: 0.3226

Corrupt JPEG data: 228 extraneous bytes before marker 0xd9

 146/146 ━━━━━━━━━━━━━━━━━━━━ 0s 260ms/step - binary_accuracy: 0.8478 - loss: 0.3209

Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9

 146/146 ━━━━━━━━━━━━━━━━━━━━ 54s 317ms/step - binary_accuracy: 0.8482 - loss: 0.3200 - val_binary_accuracy: 0.9667 - val_loss: 0.0877
Epoch 2/2
 146/146 ━━━━━━━━━━━━━━━━━━━━ 7s 51ms/step - binary_accuracy: 0.9483 - loss: 0.1232 - val_binary_accuracy: 0.9705 - val_loss: 0.0786

<keras.src.callbacks.history.History at 0x7fc8b7f1db70>

对整个模型进行一轮微调

最后,让我们解冻基础模型并以较低的学习率训练整个模型。

重要的是,尽管基础模型变为可训练,但由于我们在构建模型时调用它时传递了training=False,因此它仍在推理模式下运行。这意味着内部的批量归一化层不会更新其批量统计数据。如果它们确实更新,它们将破坏到目前为止模型学习到的表示。

# 解冻基础模型。请注意,由于我们在调用它时传递了 `training=False`,
# 它仍在推理模式下运行。这意味着批归一化层不会更新其批量统计数据。
# 这可以防止批归一化层撤销我们到目前为止所做的所有训练。
base_model.trainable = True
model.summary(show_trainable=True)

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # 较低的学习率
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 1
print("正在拟合整个模型")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
模型: "functional_4"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃ 层 (类型)                 输出形状              参数 #  可训练… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ input_layer_4 (输入层)  │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ rescaling (重缩放)       │ (None, 150, 150, 3)      │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ xception (功能性)       │ (, 5, 5, 2048)       │ 20,861…Y   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ global_average_pooling2d    │ (, 2048)             │       0-   │
│ (全局平均池化)    │                          │         │       │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dropout (丢弃)           │ (, 2048)             │       0-   │
├─────────────────────────────┼──────────────────────────┼─────────┼───────┤
│ dense_7 (密集)             │ (, 1)                │   2,049Y   │
└─────────────────────────────┴──────────────────────────┴─────────┴───────┘
 总参数: 20,867,629 (79.60 MB)
 可训练参数: 20,809,001 (79.38 MB)
 不可训练参数: 54,528 (213.00 KB)
 优化器参数: 4,100 (16.02 KB)
拟合端到端模型
 146/146 ━━━━━━━━━━━━━━━━━━━━ 75s 327ms/step - binary_accuracy: 0.8487 - loss: 0.3760 - val_binary_accuracy: 0.9494 - val_loss: 0.1160

<keras.src.callbacks.history.History at 0x7fcd1c755090>

经过10个epoch的微调,我们在这里取得了不错的改善。 让我们在测试数据集上评估模型:

print("测试数据集评估")
model.evaluate(test_ds)
测试数据集评估
 11/37 ━━━━━━━━━━━━━━━━━━━━  1s 52ms/step - binary_accuracy: 0.9407 - loss: 0.1155

腐坏的JPEG数据: 99个额外字节在标记0xd9之前

 37/37 ━━━━━━━━━━━━━━━━━━━━ 2s 47ms/step - binary_accuracy: 0.9427 - loss: 0.1259

[0.13755160570144653, 0.941300630569458]