作者: Kenneth Borup
创建日期: 2020/09/01
最后修改日期: 2020/09/01
描述: 经典知识蒸馏的实现。
知识蒸馏是一种模型压缩的方法,其中一个小的(学生)模型被训练以匹配一个大型的预训练(教师)模型。知识通过最小化损失函数从教师模型转移到学生模型,该损失函数旨在匹配软化的教师logits以及真实标签。
通过在softmax中应用“温度”缩放函数来软化logits,有效地平滑了概率分布并揭示了教师学习的类间关系。
参考文献:
import os
import keras
from keras import layers
from keras import ops
import numpy as np
Distiller()
类自定义的 Distiller()
类重写了 Model
的 compile
、compute_loss
和 call
方法。为了使用蒸馏器,我们需要:
temperature
alpha
因子在 compute_loss
方法中,我们执行教师和学生的前向传播,计算损失,并分别用 alpha
和 1 - alpha
对 student_loss
和 distillation_loss
进行加权。注意:只有学生权重会被更新。
class Distiller(keras.Model):
def __init__(self, student, teacher):
super().__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
"""配置蒸馏器。
参数:
optimizer: 学生权重的Keras优化器
metrics: 评估的Keras指标
student_loss_fn: 学生预测与真实标签之间差异的损失函数
distillation_loss_fn: 软学生预测与软教师预测之间差异的损失函数
alpha: student_loss_fn的权重和1-alpha的distillation_loss_fn
temperature: 软化概率分布的温度。更大的温度会使分布更平滑。
"""
super().compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def compute_loss(
self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
):
teacher_pred = self.teacher(x, training=False)
student_loss = self.student_loss_fn(y, y_pred)
distillation_loss = self.distillation_loss_fn(
ops.softmax(teacher_pred / self.temperature, axis=1),
ops.softmax(y_pred / self.temperature, axis=1),
) * (self.temperature**2)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
return loss
def call(self, x):
return self.student(x)
最初,我们创建一个教师模型和一个更小的学生模型。两个模型都是通过 Sequential()
创建的卷积神经网络,但可以是任何Keras模型。
teacher = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"), layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"), layers.Flatten(), layers.Dense(10), ], name="teacher", )
student = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(negative_slope=0.2), layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"), layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"), layers.Flatten(), layers.Dense(10), ], name="student", )
用于训练教师和蒸馏教师的数据集是 MNIST,这个过程对于任何其他数据集,比如 CIFAR-10,在选择合适的模型时是等价的。教师和学生都在训练集上训练,并在测试集上评估。
# 准备训练和测试数据集。
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# 数据归一化
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
在知识蒸馏中,我们假设教师已经训练好并固定了。因此,我们首先以通常的方式在训练集上训练教师模型。
# 像往常一样训练教师
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# 在数据上训练并评估教师。
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760
[0.09044107794761658, 0.978100061416626]
我们已经训练了教师模型,接下来只需要初始化一个
Distiller(student, teacher)
实例,使用所需的损失函数、
超参数和优化器进行 compile()
,然后将教师蒸馏到学生。
# 初始化并编译蒸馏器
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# 将教师蒸馏到学生
distiller.fit(x_train, y_train, epochs=3)
# 在测试数据集上评估学生
distiller.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629
[0.017046602442860603, 0.969200074672699]
我们还可以从头开始训练一个等效的学生模型,而不使用教师,以便评估通过知识蒸馏获得的性能提升。
# 和往常一样训练学生
student_scratch.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# 训练并评估从头训练的学生。
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737
[0.0629437193274498, 0.9778000712394714]
如果教师训练了5个完整的训练周期,而学生基于该教师蒸馏训练了3个完整的训练周期,那么在这个例子中,你应该会体验到与从头训练同一个学生模型相比的性能提升,甚至与教师本身相比。你应该期望教师的准确率在97.6%左右,从头训练的学生应该在97.6%左右,而蒸馏后的学生应该在98.1%左右。移除或尝试不同的随机种子以使用不同的权重初始化。