作者: Rishit Dagli
创建日期: 06/18/21
最后修改: 07/25/23
描述: 实现梯度中心化以提高深度神经网络的训练性能。
此示例实现了 梯度中心化,这是 Yong 等人提出的一种用于深度神经网络的新优化技术,并在 Laurence Moroney 的 马或人 数据集上进行了演示。梯度中心化可以加速训练过程并提高 DNN 的最终泛化性能。它通过将梯度向量中心化以使均值为零来直接作用于梯度。此外,梯度中心化提高了损失函数及其梯度的利普希茨性,从而使训练过程变得更加高效和稳定。
此示例需要 tensorflow_datasets
,可以使用以下命令安装:
pip install tensorflow-datasets
from time import time
import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops
from tensorflow import data as tf_data
import tensorflow_datasets as tfds
在此示例中,我们将使用 马或人 数据集。
num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf_data.AUTOTUNE
(train_ds, test_ds), metadata = tfds.load(
name=dataset_name,
split=[tfds.Split.TRAIN, tfds.Split.TEST],
with_info=True,
as_supervised=True,
)
print(f"图像形状: {metadata.features['image'].shape}")
print(f"训练图像: {metadata.splits['train'].num_examples}")
print(f"测试图像: {metadata.splits['test'].num_examples}")
图像形状: (300, 300, 3)
训练图像: 1027
测试图像: 256
我们将数据重新缩放到 [0, 1]
并对数据进行简单的增强。
rescale = layers.Rescaling(1.0 / 255)
data_augmentation = [
layers.RandomFlip("horizontal_and_vertical"),
layers.RandomRotation(0.3),
layers.RandomZoom(0.2),
]
# 应用增强的辅助工具
def apply_aug(x):
for aug in data_augmentation:
x = aug(x)
return x
def prepare(ds, shuffle=False, augment=False):
# 重新缩放数据集
ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)
if shuffle:
ds = ds.shuffle(1024)
# 批处理数据集
ds = ds.batch(batch_size)
# 仅在训练集上使用数据增强
if augment:
ds = ds.map(
lambda x, y: (apply_aug(x), y),
num_parallel_calls=AUTOTUNE,
)
# 使用缓冲预取
return ds.prefetch(buffer_size=AUTOTUNE)
重新缩放和增强数据
train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)
在本节中,我们将定义一个卷积神经网络。
model = keras.Sequential(
[
layers.Input(shape=input_shape),
layers.Conv2D(16, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Conv2D(32, (3, 3), activation="relu"),
layers.Dropout(0.5),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.Dropout(0.5),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Conv2D(64, (3, 3), activation="relu"),
layers.MaxPooling2D(2, 2),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(512, activation="relu"),
layers.Dense(1, activation="sigmoid"),
]
)
我们现在将子类化 RMSProp
优化器类,修改 keras.optimizers.Optimizer.get_gradients()
方法,在此实现梯度中心化。总体而言,假设我们通过反向传播获得了密集或卷积层的梯度,然后计算权重矩阵的列向量的均值,然后从每个列向量中移除均值。
在 这篇论文 中的实验显示, 应用程序,包括一般图像分类、细粒度图像分类、检测和分割以及人物再识别(Person ReID)显示,GC(梯度中心化)可以持续提高深度神经网络学习的性能。
此外,出于简单考虑,目前我们没有实现梯度裁剪功能,但这相对简单易实现。
目前我们只是为 RMSProp
优化器创建一个子类,但您可以轻松地为任何其他优化器或自定义优化器以相同的方式重现这一点。我们将在后面的章节中使用这个类来训练带有梯度中心化的模型。
class GCRMSprop(RMSprop):
def get_gradients(self, loss, params):
# 我们在这里仅提供一个修改过的get_gradients()函数,因
# 为我们只是尝试计算中心化梯度。
grads = []
gradients = super().get_gradients()
for grad in gradients:
grad_len = len(grad.shape)
if grad_len > 1:
axis = list(range(grad_len - 1))
grad -= ops.mean(grad, axis=axis, keep_dims=True)
grads.append(grad)
return grads
optimizer = GCRMSprop(learning_rate=1e-4)
我们还将创建一个回调,使我们能够轻松测量总训练时间和每个时期所需的时间,因为我们有兴趣比较梯度中心化对上述模型的影响。
class TimeHistory(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.times = []
def on_epoch_begin(self, batch, logs={}):
self.epoch_time_start = time()
def on_epoch_end(self, batch, logs={}):
self.times.append(time() - self.epoch_time_start)
我们现在在不使用梯度中心化的情况下训练之前构建的模型,以便与使用梯度中心化训练的模型的训练性能进行比较。
time_callback_no_gc = TimeHistory()
model.compile(
loss="binary_crossentropy",
optimizer=RMSprop(learning_rate=1e-4),
metrics=["accuracy"],
)
model.summary()
模型: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d (卷积层) │ (None, 298, 298, 16) │ 448 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d (最大池化层) │ (None, 149, 149, 16) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_1 (卷积层) │ (None, 147, 147, 32) │ 4,640 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (Dropout) │ (None, 147, 147, 32) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 73, 73, 32) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_2 (Conv2D) │ (None, 71, 71, 64) │ 18,496 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout_1 (Dropout) │ (None, 71, 71, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 35, 35, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_3 (Conv2D) │ (None, 33, 33, 64) │ 36,928 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_4 (Conv2D) │ (None, 14, 14, 64) │ 36,928 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_4 (MaxPooling2D) │ (None, 7, 7, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten (Flatten) │ (None, 3136) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout_2 (Dropout) │ (None, 3136) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (Dense) │ (None, 512) │ 1,606,144 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_1 (Dense) │ (None, 1) │ 513 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
总参数: 1,704,097 (6.50 MB)
可训练参数: 1,704,097 (6.50 MB)
不可训练参数: 0 (0.00 B)
我们还保存了历史记录,因为我们稍后想比较使用和不使用梯度中心化训练的模型
history_no_gc = model.fit(
train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)
第 1/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 24s 778ms/step - 准确率: 0.4772 - 损失: 0.7405
第 2/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 597ms/step - 准确率: 0.5434 - 损失: 0.6861
第 3/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 700ms/step - 准确率: 0.5402 - 损失: 0.6911
第 4/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 586ms/step - 准确率: 0.5884 - 损失: 0.6788
第 5/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 588ms/step - 准确率: 0.6570 - 损失: 0.6564
第 6/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 591ms/step - 准确率: 0.6671 - 损失: 0.6395
第 7/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - 准确率: 0.7010 - 损失: 0.6161
第 8/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 593ms/step - 准确率: 0.6946 - 损失: 0.6129
第 9/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 699ms/step - 准确率: 0.6972 - 损失: 0.5987
第 10/10 轮
9/9 ━━━━━━━━━━━━━━━━━━━━ 11s 623ms/step - 准确率: 0.6839 - 损失: 0.6197
我们现在将使用梯度中心化训练相同的模型,注意这次我们的优化器是使用梯度中心化的。
time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
model.summary()
history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])
模型: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d (卷积层) │ (None, 298, 298, 16) │ 448 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d (最大池化层) │ (None, 149, 149, 16) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_1 (卷积层) │ (None, 147, 147, 32) │ 4,640 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (丢弃层) │ (None, 147, 147, 32) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_1 (最大池化层) │ (None, 73, 73, 32) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_2 (卷积层) │ (None, 71, 71, 64) │ 18,496 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout_1 (丢弃层) │ (None, 71, 71, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_2 (最大池化层) │ (None, 35, 35, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_3 (卷积层) │ (None, 33, 33, 64) │ 36,928 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 16, 16, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_4 (Conv2D) │ (None, 14, 14, 64) │ 36,928 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_4 (MaxPooling2D) │ (None, 7, 7, 64) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten (Flatten) │ (None, 3136) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout_2 (Dropout) │ (None, 3136) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (Dense) │ (None, 512) │ 1,606,144 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_1 (Dense) │ (None, 1) │ 513 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
总参数: 1,704,097 (6.50 MB)
可训练参数: 1,704,097 (6.50 MB)
不可训练参数: 0 (0.00 B)
纪元 1/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 12s 649ms/步 - 准确率: 0.7118 - 损失: 0.5594
纪元 2/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 592ms/步 - 准确率: 0.7249 - 损失: 0.5817
纪元 3/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/步 - 准确率: 0.8060 - 损失: 0.4448
纪元 4/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 693ms/步 - 准确率: 0.8472 - 损失: 0.4051
纪元 5/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/步 - 准确率: 0.8386 - 损失: 0.3978
纪元 6/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 593ms/步 - 准确率: 0.8442 - 损失: 0.3976
纪元 7/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 585ms/步 - 准确率: 0.7409 - 损失: 0.6626
纪元 8/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 587ms/步 - 准确率: 0.8191 - 损失: 0.4357
纪元 9/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/步 - 准确率: 0.8248 - 损失: 0.3974
纪元 10/10
9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 646ms/步 - 准确率: 0.8022 - 损失: 0.4589
print("未使用梯度中心化")
print(f"损失: {history_no_gc.history['loss'][-1]}")
print(f"准确率: {history_no_gc.history['accuracy'][-1]}")
print(f"训练时间: {sum(time_callback_no_gc.times)}")
print("使用梯度中心化")
print(f"损失: {history_gc.history['loss'][-1]}")
print(f"准确率: {history_gc.history['accuracy'][-1]}")
print(f"训练时间: {sum(time_callback_gc.times)}")
未使用梯度中心化
损失: 0.5345584154129028
准确率: 0.7604166865348816
训练时间: 112.48799777030945
使用梯度中心化
损失: 0.4014038145542145
准确率: 0.8153935074806213
训练时间: 98.31573963165283
鼓励读者在来自不同领域的不同数据集上尝试梯度中心化,并实验其效果。强烈建议您查看原始论文——作者展示了几个关于梯度中心化的研究,表明它可以提高整体性能、泛化能力、训练时间以及更高的效率。
非常感谢Ali Mustufa Shaikh对本实现的审阅。