代码示例 / 计算机视觉 / 知识蒸馏

知识蒸馏

作者: Kenneth Borup
创建日期: 2020/09/01
最后修改日期: 2020/09/01
描述: 经典知识蒸馏的实现。

在Colab中查看 GitHub源代码


知识蒸馏简介

知识蒸馏是一种模型压缩的方法,其中一个小的(学生)模型被训练以匹配一个大型的预训练(教师)模型。知识通过最小化损失函数从教师模型转移到学生模型,该损失函数旨在匹配软化的教师logits以及真实标签。

通过在softmax中应用“温度”缩放函数来软化logits,有效地平滑了概率分布并揭示了教师学习的类间关系。

参考文献:


设置

import os

import keras
from keras import layers
from keras import ops
import numpy as np

构建 Distiller()

自定义的 Distiller() 类重写了 Modelcompilecompute_losscall 方法。为了使用蒸馏器,我们需要:

  • 一个训练好的教师模型
  • 一个要训练的学生模型
  • 一个计算学生预测与真实标签之间差异的学生损失函数
  • 一个计算软学生预测与软教师标签之间差异的蒸馏损失函数,以及一个temperature
  • 一个权重学生损失和蒸馏损失的alpha因子
  • 一个优化器用于学生模型和(可选的)评估性能的指标

compute_loss 方法中,我们执行教师和学生的前向传播,计算损失,并分别用 alpha1 - alphastudent_lossdistillation_loss 进行加权。注意:只有学生权重会被更新。

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """配置蒸馏器。

        参数:
            optimizer: 学生权重的Keras优化器
            metrics: 评估的Keras指标
            student_loss_fn: 学生预测与真实标签之间差异的损失函数
            distillation_loss_fn: 软学生预测与软教师预测之间差异的损失函数
            alpha: student_loss_fn的权重和1-alpha的distillation_loss_fn
            temperature: 软化概率分布的温度。更大的温度会使分布更平滑。
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)

创建学生和教师模型

最初,我们创建一个教师模型和一个更小的学生模型。两个模型都是通过 Sequential() 创建的卷积神经网络,但可以是任何Keras模型。

创建教师模型

teacher = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"), layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"), layers.Flatten(), layers.Dense(10), ], name="teacher", )

创建学生模型

student = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"), layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"), layers.Flatten(), layers.Dense(10), ], name="student", )

克隆学生模型以便于后续比较

student_scratch = keras.models.clone_model(student)

准备数据集

用于训练教师和蒸馏教师的数据集是 MNIST,这个过程对于任何其他数据集,比如 CIFAR-10,在选择合适的模型时是等价的。教师和学生都在训练集上训练,并在测试集上评估。

# 准备训练和测试数据集。
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 数据归一化
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

训练教师

在知识蒸馏中,我们假设教师已经训练好并固定了。因此,我们首先以通常的方式在训练集上训练教师模型。

# 像往常一样训练教师
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# 在数据上训练并评估教师。
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760

[0.09044107794761658, 0.978100061416626]

将教师蒸馏到学生

我们已经训练了教师模型,接下来只需要初始化一个 Distiller(student, teacher) 实例,使用所需的损失函数、 超参数和优化器进行 compile(),然后将教师蒸馏到学生。

# 初始化并编译蒸馏器
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# 将教师蒸馏到学生
distiller.fit(x_train, y_train, epochs=3)

# 在测试数据集上评估学生
distiller.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
 313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629

[0.017046602442860603, 0.969200074672699]

从头开始训练学生以进行比较

我们还可以从头开始训练一个等效的学生模型,而不使用教师,以便评估通过知识蒸馏获得的性能提升。

# 和往常一样训练学生
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# 训练并评估从头训练的学生。
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737

[0.0629437193274498, 0.9778000712394714]

如果教师训练了5个完整的训练周期,而学生基于该教师蒸馏训练了3个完整的训练周期,那么在这个例子中,你应该会体验到与从头训练同一个学生模型相比的性能提升,甚至与教师本身相比。你应该期望教师的准确率在97.6%左右,从头训练的学生应该在97.6%左右,而蒸馏后的学生应该在98.1%左右。移除或尝试不同的随机种子以使用不同的权重初始化。