作者: Sayak Paul
创建日期: 2021/08/01
最后修改日期: 2021/08/01
描述: 通过知识蒸馏和函数匹配训练更好的学生模型。
知识蒸馏 (Hinton et al.) 是一种技术,可以将较大的模型压缩成较小的模型。这使我们能够获得高性能较大模型的好处,同时降低存储和内存成本,并实现更高的推理速度:
在知识蒸馏:好老师是耐心和一致的中,Beyer et al. 探索了各种现有的知识蒸馏方法,并证明它们都会导致次优表现。因此,实践者在开发资源受限的生产系统时,通常会选择其他替代方案(量化、剪枝、权重聚类等)。
Beyer et al. 研究了如何改善从知识蒸馏过程中得到的学生模型,使其能够始终匹配其教师模型的性能。在这个例子中,我们将研究他们提出的食谱,使用Flowers102 数据集。作为参考,通过这些食谱,作者能够生成一个在 ImageNet-1k 数据集上达到 82.8% 准确率的 ResNet50 模型。
如果您需要对知识蒸馏进行复习,并想了解它在 Keras 中的实现,可以参考这个例子。您还可以关注这个例子,它展示了知识蒸馏扩展到一致性训练的应用。
要跟随这个例子,您需要 TensorFlow 2.5 或更高版本以及 TensorFlow Addons,可以使用以下命令安装:
!pip install -q tensorflow-addons
from tensorflow import keras
import tensorflow_addons as tfa
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
AUTO = tf.data.AUTOTUNE # 用于动态调整并行性。
BATCH_SIZE = 64
# 来自表 4 和“训练设置”部分。
TEMPERATURE = 10 # 用于在它们进入 softmax 之前软化 logits。
INIT_LR = 0.003 # 初始学习率,将在训练期间逐渐降低。
WEIGHT_DECAY = 0.001 # 用于正则化。
CLIP_THRESHOLD = 1.0 # 用于按 L2 范数剪辑梯度。
# 我们将首先将训练图像调整为更大的大小,然后取随机裁剪的更小的大小。
BIGGER = 160
RESIZE = 128
train_ds, validation_ds, test_ds = tfds.load(
"oxford_flowers102", split=["train", "validation", "test"], as_supervised=True
)
print(f"训练示例的数量: {train_ds.cardinality()}.")
print(
f"验证示例的数量: {validation_ds.cardinality()}."
)
print(f"测试示例的数量: {test_ds.cardinality()}.")
训练示例的数量: 1020.
验证示例的数量: 1020.
测试示例的数量: 6149.
与任何蒸馏技术一样,首先训练一个表现良好的教师模型是重要的,该模型通常比后续的学生模型更大。作者将一个 BiT ResNet152x2 模型(教师)蒸馏到一个 BiT ResNet50 模型(学生)。
BiT 代表大迁移,最早在大迁移(BiT):通用视觉表示学习中提出。BiT 变体的 ResNet 使用组归一化 (Wu et al.) 和权重标准化 (Qiao et al.) 来替代批量归一化 (Ioffe et al.)。为了限制运行此示例所花费的时间,我们将使用一个已经在 Flowers102 数据集上训练好的 BiT ResNet101x3。您可以参考这个笔记本 了解更多关于训练过程的信息。该模型在Flowers102的测试集上达到了98.18%的准确率。
模型的权重被托管在Kaggle上作为一个数据集。 要下载权重,请按照以下步骤操作:
kaggle.json
,该文件包含您的API凭据。现在运行以下代码:
import os
os.environ["KAGGLE_USERNAME"] = "" # TODO: 在这里输入您的Kaggle用户名
os.environ["KAGGLE_KEY"] = "" # TODO: 在这里输入您的Kaggle密钥
一旦环境变量设置好,运行:
$ kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
$ unzip -qq bitresnet101x3flowers102.zip
这应该会生成一个名为T-r101x3-128
的文件夹,这实际上是一个教师
SavedModel
。
import os
os.environ["KAGGLE_USERNAME"] = "" # TODO: 在这里输入您的Kaggle用户名
os.environ["KAGGLE_KEY"] = "" # TODO: 在这里输入您的Kaggle API密钥
!kaggle datasets download -d spsayakpaul/bitresnet101x3flowers102
!unzip -qq bitresnet101x3flowers102.zip
# 由于教师模型不会进一步训练,我们将其设为
# 不可训练。
teacher_model = keras.models.load_model(
"/home/jupyter/keras-io/examples/keras_recipes/T-r101x3-128"
)
teacher_model.trainable = False
teacher_model.summary()
Model: "my_bi_t_model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) multiple 626790
_________________________________________________________________
keras_layer_1 (KerasLayer) multiple 381789888
=================================================================
Total params: 382,416,678
Trainable params: 0
Non-trainable params: 382,416,678
_________________________________________________________________
为了训练一个高质量的学生模型,作者建议对学生训练工作流程进行以下更改:
alpha
参数来完成,而不是从贝塔分布中获取。这里使用MixUp是为了帮助学生模型捕捉教师模型的底层函数。MixUp在线性插值不同样本之间在数据流形上。因此,这里的理由是如果学生被训练以适应它,那么它应该能够更好地匹配教师模型。为了包含更多的不变性,MixUp与“Inception风格”的裁剪结合使用
(Szegedy et al.)。这就是“函数匹配”术语出现在
原始论文中的原因。总之,在训练学生模型时,必须保持一致和耐心。
def mixup(images, labels):
alpha = tf.random.uniform([], 0, 1)
mixedup_images = alpha * images + (1 - alpha) * tf.reverse(images, axis=[0])
# 标签在这里不重要,因为它们在训练期间未被使用。
return mixedup_images, labels
def preprocess_image(image, label, train=True):
image = tf.cast(image, tf.float32) / 255.0
if train:
image = tf.image.resize(image, (BIGGER, BIGGER))
image = tf.image.random_crop(image, (RESIZE, RESIZE, 3))
image = tf.image.random_flip_left_right(image)
else:
# 中心裁剪比例来自这里:
# https://git.io/J8Kda.
image = tf.image.central_crop(image, central_fraction=0.875)
image = tf.image.resize(image, (RESIZE, RESIZE))
return image, label
def prepare_dataset(dataset, train=True, batch_size=BATCH_SIZE):
if train:
dataset = dataset.map(preprocess_image, num_parallel_calls=AUTO)
dataset = dataset.shuffle(BATCH_SIZE * 10)
else:
dataset = dataset.map(
lambda x, y: (preprocess_image(x, y, train)), num_parallel_calls=AUTO
)
dataset = dataset.batch(batch_size)
if train:
dataset = dataset.map(mixup, num_parallel_calls=AUTO)
dataset = dataset.prefetch(AUTO)
return dataset
注意为了简洁,我们在训练集上使用了轻度裁剪,但实际上应该应用“启发式”预处理。您可以参考 这个脚本 以获取更详细的实现。此外,真实标签不用于训练学生模型。
train_ds = prepare_dataset(train_ds, True)
validation_ds = prepare_dataset(validation_ds, False)
test_ds = prepare_dataset(test_ds, False)
sample_images, _ = next(iter(train_ds))
plt.figure(figsize=(10, 10))
for n in range(25):
ax = plt.subplot(5, 5, n + 1)
plt.imshow(sample_images[n].numpy())
plt.axis("off")
plt.show()
为了这个示例,我们将使用标准的 ResNet50V2 (He et al.)。
def get_resnetv2():
resnet_v2 = keras.applications.ResNet50V2(
weights=None,
input_shape=(RESIZE, RESIZE, 3),
classes=102,
classifier_activation="linear",
)
return resnet_v2
get_resnetv2().count_params()
23773798
与教师模型相比,该模型减少了358百万个参数。
我们将重用 这个示例 中的一些代码来进行知识蒸馏。
class Distiller(tf.keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.student = student
self.teacher = teacher
self.loss_tracker = keras.metrics.Mean(name="distillation_loss")
@property
def metrics(self):
metrics = super().metrics
metrics.append(self.loss_tracker)
return metrics
def compile(
self, optimizer, metrics, distillation_loss_fn, temperature=TEMPERATURE,
):
super().compile(optimizer=optimizer, metrics=metrics)
self.distillation_loss_fn = distillation_loss_fn
self.temperature = temperature
def train_step(self, data):
# 解包数据
x, _ = data
# 教师的前向传播
teacher_predictions = self.teacher(x, training=False)
with tf.GradientTape() as tape:
# 学生的前向传播
student_predictions = self.student(x, training=True)
# 计算损失
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
# 计算梯度
trainable_vars = self.student.trainable_variables
gradients = tape.gradient(distillation_loss, trainable_vars)
# 更新权重
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# 报告进度
self.loss_tracker.update_state(distillation_loss)
return {"distillation_loss": self.loss_tracker.result()}
def test_step(self, data):
# 解包数据
x, y = data
# 前向传播
teacher_predictions = self.teacher(x, training=False)
student_predictions = self.student(x, training=False)
# 计算损失
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
# 报告进度
self.loss_tracker.update_state(distillation_loss)
self.compiled_metrics.update_state(y, student_predictions)
results = {m.name: m.result() for m in self.metrics}
return results
论文中使用了一种预热余弦学习率调整。这种调整对于许多预训练方法,尤其是计算机视觉领域,是典型的。
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
def __init__(
self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
):
super().__init__()
self.learning_rate_base = learning_rate_base
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.pi = tf.constant(np.pi)
def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
cos_annealed_lr = tf.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
if self.warmup_steps > 0:
if self.learning_rate_base < self.warmup_learning_rate:
raise ValueError(
"Learning_rate_base must be larger or equal to "
"warmup_learning_rate."
)
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
learning_rate = tf.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
return tf.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
我们现在可以绘制一个使用此计划生成的学习率图。
ARTIFICIAL_EPOCHS = 1000
ARTIFICIAL_BATCH_SIZE = 512
DATASET_NUM_TRAIN_EXAMPLES = 1020
TOTAL_STEPS = int(
DATASET_NUM_TRAIN_EXAMPLES / ARTIFICIAL_BATCH_SIZE * ARTIFICIAL_EPOCHS
)
scheduled_lrs = WarmUpCosine(
learning_rate_base=INIT_LR,
total_steps=TOTAL_STEPS,
warmup_learning_rate=0.0,
warmup_steps=1500,
)
lrs = [scheduled_lrs(step) for step in range(TOTAL_STEPS)]
plt.plot(lrs)
plt.xlabel("步骤", fontsize=14)
plt.ylabel("学习率 (LR)", fontsize=14)
plt.show()
原始论文使用至少1000个epochs和512的批量大小进行“函数匹配”。本示例的目的是展示实现该食谱的工作流程,而不是演示在全规模应用时的结果。然而,这些食谱可以转移到论文中的原始设置。如果您有兴趣了解更多信息,请参阅这个仓库。
optimizer = tfa.optimizers.AdamW(
weight_decay=WEIGHT_DECAY, learning_rate=scheduled_lrs, clipnorm=CLIP_THRESHOLD
)
student_model = get_resnetv2()
distiller = Distiller(student=student_model, teacher=teacher_model)
distiller.compile(
optimizer,
metrics=[keras.metrics.SparseCategoricalAccuracy()],
distillation_loss_fn=keras.losses.KLDivergence(),
temperature=TEMPERATURE,
)
history = distiller.fit(
train_ds,
steps_per_epoch=int(np.ceil(DATASET_NUM_TRAIN_EXAMPLES / BATCH_SIZE)),
validation_data=validation_ds,
epochs=30, # 这应该至少是1000个epochs。
)
student = distiller.student
student_model.compile(metrics=["准确率"])
_, top1_accuracy = student.evaluate(test_ds)
print(f"测试集上的Top-1准确率: {round(top1_accuracy * 100, 2)}%")
Epoch 1/30
16/16 [==============================] - 74s 3s/step - 蒸馏损失: 0.0070 - val_sparse_categorical_accuracy: 0.0039 - val_distillation_loss: 0.0061
Epoch 2/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0059 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0061
Epoch 3/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0049 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0060
Epoch 4/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0048 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0060
Epoch 5/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0043 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0060
Epoch 6/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0041 - val_sparse_categorical_accuracy: 0.0108 - val_distillation_loss: 0.0060
Epoch 7/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0038 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0061
Epoch 8/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0040 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0062
Epoch 9/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0039 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0063
Epoch 10/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0035 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0064
Epoch 11/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0041 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0064
Epoch 12/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0039 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0067
Epoch 13/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0039 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0067
Epoch 14/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0036 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0066
Epoch 15/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0037 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0065
Epoch 16/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0038 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0068
Epoch 17/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0039 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0066
Epoch 18/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0038 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0064
Epoch 19/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0035 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0071
Epoch 20/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0038 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0066
Epoch 21/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0038 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0068
Epoch 22/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0034 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0073
Epoch 23/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0035 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0078
Epoch 24/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0037 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0087
Epoch 25/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0031 - val_sparse_categorical_accuracy: 0.0108 - val_distillation_loss: 0.0078
Epoch 26/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0033 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0072
Epoch 27/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0036 - val_sparse_categorical_accuracy: 0.0098 - val_distillation_loss: 0.0071
Epoch 28/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0036 - val_sparse_categorical_accuracy: 0.0275 - val_distillation_loss: 0.0078
Epoch 29/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0032 - val_sparse_categorical_accuracy: 0.0196 - val_distillation_loss: 0.0068
Epoch 30/30
16/16 [==============================] - 37s 2s/step - 蒸馏损失: 0.0034 - val_sparse_categorical_accuracy: 0.0147 - val_distillation_loss: 0.0071
97/97 [==============================] - 7s 64ms/step - 损失: 0.0000e+00 - 准确率: 0.0107
测试集的Top-1准确率: 1.07%
仅用30个训练周期,结果远未达到预期。这就是耐心的好处,也就是更长的训练时间将发挥作用。让我们调查一下训练了1000个周期的模型能够做到什么。
# 下载预训练权重。
!wget https://git.io/JBO3Y -O S-r50x1-128-1000.tar.gz
!tar xf S-r50x1-128-1000.tar.gz
pretrained_student = keras.models.load_model("S-r50x1-128-1000")
pretrained_student.summary()
模型: "resnet"
_________________________________________________________________
层 (类型) 输出形状 参数 #
=================================================================
root_block (Sequential) (None, 32, 32, 64) 9408
_________________________________________________________________
block1 (Sequential) (None, 32, 32, 256) 214912
_________________________________________________________________
block2 (Sequential) (None, 16, 16, 512) 1218048
_________________________________________________________________
block3 (Sequential) (None, 8, 8, 1024) 7095296
_________________________________________________________________
block4 (Sequential) (None, 4, 4, 2048) 14958592
_________________________________________________________________
group_norm (GroupNormalizati multiple 4096
_________________________________________________________________
re_lu_97 (ReLU) multiple 0
_________________________________________________________________
global_average_pooling2d_1 ( multiple 0
_________________________________________________________________
head/dense (Dense) multiple 208998
=================================================================
总参数: 23,709,350
可训练参数: 23,709,350
不可训练参数: 0
_________________________________________________________________
该模型完全遵循作者在其学生模型中使用的结构。这就是模型摘要略有不同的原因。
_, top1_accuracy = pretrained_student.evaluate(test_ds)
print(f"测试集上的 Top-1 准确率: {round(top1_accuracy * 100, 2)}%")
97/97 [==============================] - 14s 131ms/step - loss: 0.0000e+00 - accuracy: 0.8102
测试集上的 Top-1 准确率: 81.02%
经过100000个训练周期,这个模型的Top-1准确率达到95.54%。
论文中提出了一些重要的消融研究,展示了这些方法相较于以前技术的有效性。因此,如果您对这些方法持怀疑态度,务必查阅论文。
借助基于TPU的硬件基础设施,我们可以更快地将模型训练1000个周期。这甚至不需要对这个代码库进行很多更改。我们鼓励您查看 这个仓库,因为它提供了这些方法的TPU兼容训练工作流程,并可以在利用他们的免费TPU v3-8硬件的 Kaggle Kernel 上运行。