Author: Khalid Salama
Date created: 2020/11/30
Last modified: 2020/11/30
Description: Using supervised contrastive learning for image classification.
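Supervised Contrastive Learning (Prannay Khosla et al.) is a training methodology that outperforms supervised training with cross-entropy on classification tasks.
Essentially, training an image classification model with supervised contrastive learning is performed in two phases: first, an encoder is pretrained to optimize the supervised contrastive loss; second, a classifier is trained on top of the encoder with the encoder weights frozen.
Note that this example requires TensorFlow Addons, which you can install using the following command: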
pip install tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
num_classes = 10
input_shape = (32, 32, 3)
# Load the train and test data splits
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
# Display the shapes of the train and test data splits
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.02),
    ]
)
# Set the state of the normalization layer.
data_augmentation.layers[0].adapt(x_train)
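To make the `adapt` call less opaque, here is a minimal NumPy sketch of what the normalization step is doing, assuming the layer's default of one mean and one variance per channel (the last axis). The random `images` array is a made-up stand-in for `x_train`, and the flip at the end only illustrates what `RandomFlip("horizontal")` does when it fires; this is not the Keras implementation.

```python
import numpy as np

# Made-up stand-in for x_train: 64 random images in the CIFAR-10 layout.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(64, 32, 32, 3)).astype("float32")

# "adapt": one mean and one variance per channel (the last axis).
channel_mean = images.mean(axis=(0, 1, 2))
channel_var = images.var(axis=(0, 1, 2))

# Forward pass of the normalization step: zero mean, unit variance per channel.
normalized = (images - channel_mean) / np.sqrt(channel_var)

# A horizontal flip mirrors the width axis (axis 2) of every image.
flipped = normalized[:, :, ::-1, :]

print("per-channel mean:", normalized.mean(axis=(0, 1, 2)))
print("per-channel std: ", normalized.std(axis=(0, 1, 2)))
```

Because the statistics are computed once from the training set, the same shift and scale are applied to every image at training and at test time.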
The encoder model takes an image as input and turns it into a 2048-dimensional feature vector.
def create_encoder():
    resnet = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg"
    )
    inputs = keras.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    outputs = resnet(augmented)
    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-encoder")
    return model
encoder = create_encoder()
encoder.summary()
learning_rate = 0.001
batch_size = 265
hidden_units = 512
projection_units = 128
num_epochs = 50
dropout_rate = 0.5
temperature = 0.05
Model: "cifar10-encoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
sequential (Sequential) (None, None, None, 3) 7
_________________________________________________________________
resnet50v2 (Functional) (None, 2048) 23564800
=================================================================
Total params: 23,564,807
Trainable params: 23,519,360
Non-trainable params: 45,447
_________________________________________________________________
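The `(None, 2048)` output shape in the summary comes from `pooling="avg"` in `create_encoder`, which averages the final convolutional feature map over its spatial positions. The sketch below shows that reduction in NumPy; the 2x2 spatial size and the random values are made up for illustration and are not taken from the real network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical final feature map: (batch, height, width, channels).
feature_map = rng.normal(size=(4, 2, 2, 2048))

# Global average pooling: average over the two spatial axes,
# leaving one 2048-dimensional vector per image.
pooled = feature_map.mean(axis=(1, 2))

print(pooled.shape)
```

Whatever the spatial size of the feature map, the pooled vector always has one entry per channel, which is why the encoder output is 2048-dimensional.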
The classification model adds a fully-connected layer on top of the encoder, plus a softmax layer with the target classes.
def create_classifier(encoder, trainable=True):
    for layer in encoder.layers:
        layer.trainable = trainable
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dropout(dropout_rate)(features)
    features = layers.Dense(hidden_units, activation="relu")(features)
    features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(num_classes, activation="softmax")(features)
    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-classifier")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model
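As a quick sanity check on the size of the classification head, the parameter counts that the classifier summary reports for the two `Dense` layers can be reproduced by hand: a dense layer has one weight per input-output pair plus one bias per output. The helper name `dense_params` is made up for this sketch; the encoder count is copied from the encoder summary above.

```python
def dense_params(inputs: int, units: int) -> int:
    """Weights (inputs * units) plus one bias per unit."""
    return inputs * units + units


encoder_params = 23_564_807            # total reported for "cifar10-encoder"
hidden_params = dense_params(2048, 512)  # Dense(hidden_units) on 2048 features
output_params = dense_params(512, 10)    # Dense(num_classes) on 512 features
total_params = encoder_params + hidden_params + output_params

print(hidden_params, output_params, total_params)
```

The dropout layers contribute no parameters, so the head adds only about one million weights on top of the encoder.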
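In this experiment, a baseline classifier is trained as usual, i.e., the encoder and the classifier parts are trained together as a single model to minimize the cross-entropy loss.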
encoder = create_encoder()
classifier = create_classifier(encoder)
classifier.summary()
history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
Model: "cifar10-classifier"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
cifar10-encoder (Functional) (None, 2048) 23564807
_________________________________________________________________
dropout (Dropout) (None, 2048) 0
_________________________________________________________________
dense (Dense) (None, 512) 1049088
_________________________________________________________________
dropout_1 (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 10) 5130
=================================================================
Total params: 24,619,025
Trainable params: 24,573,578
Non-trainable params: 45,447
_________________________________________________________________
Epoch 1/50
189/189 [==============================] - 15s 77ms/step - loss: 1.9369 - sparse_categorical_accuracy: 0.2874
Epoch 2/50
189/189 [==============================] - 11s 57ms/step - loss: 1.5133 - sparse_categorical_accuracy: 0.4505
Epoch 3/50
189/189 [==============================] - 11s 57ms/step - loss: 1.3468 - sparse_categorical_accuracy: 0.5204
Epoch 4/50
189/189 [==============================] - 11s 60ms/step - loss: 1.2159 - sparse_categorical_accuracy: 0.5733
Epoch 5/50
189/189 [==============================] - 11s 56ms/step - loss: 1.1516 - sparse_categorical_accuracy: 0.6032
Epoch 6/50
189/189 [==============================] - 11s 58ms/step - loss: 1.0769 - sparse_categorical_accuracy: 0.6254
Epoch 7/50
189/189 [==============================] - 11s 58ms/step - loss: 0.9964 - sparse_categorical_accuracy: 0.6547
Epoch 8/50
189/189 [==============================] - 10s 55ms/step - loss: 0.9563 - sparse_categorical_accuracy: 0.6703
Epoch 9/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8952 - sparse_categorical_accuracy: 0.6925
Epoch 10/50
189/189 [==============================] - 11s 56ms/step - loss: 0.8986 - sparse_categorical_accuracy: 0.6922
Epoch 11/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8381 - sparse_categorical_accuracy: 0.7145
Epoch 12/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8513 - sparse_categorical_accuracy: 0.7086
Epoch 13/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7557 - sparse_categorical_accuracy: 0.7448
Epoch 14/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7168 - sparse_categorical_accuracy: 0.7548
Epoch 15/50
189/189 [==============================] - 10s 55ms/step - loss: 0.6772 - sparse_categorical_accuracy: 0.7690
Epoch 16/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7587 - sparse_categorical_accuracy: 0.7416
Epoch 17/50
189/189 [==============================] - 10s 55ms/step - loss: 0.6873 - sparse_categorical_accuracy: 0.7665
Epoch 18/50
189/189 [==============================] - 11s 56ms/step - loss: 0.6418 - sparse_categorical_accuracy: 0.7804
Epoch 19/50
189/189 [==============================] - 11s 56ms/step - loss: 0.6086 - sparse_categorical_accuracy: 0.7927
Epoch 20/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5903 - sparse_categorical_accuracy: 0.7978
Epoch 21/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5636 - sparse_categorical_accuracy: 0.8083
Epoch 22/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5527 - sparse_categorical_accuracy: 0.8123
Epoch 23/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5308 - sparse_categorical_accuracy: 0.8191
Epoch 24/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5282 - sparse_categorical_accuracy: 0.8223
Epoch 25/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5090 - sparse_categorical_accuracy: 0.8263
Epoch 26/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5497 - sparse_categorical_accuracy: 0.8181
Epoch 27/50
189/189 [==============================] - 10s 55ms/step - loss: 0.4950 - sparse_categorical_accuracy: 0.8332
Epoch 28/50
189/189 [==============================] - 11s 56ms/step - loss: 0.4727 - sparse_categorical_accuracy: 0.8391
Epoch 29/50
167/189 [=========================>....] - ETA: 1s - loss: 0.4594 - sparse_categorical_accuracy: 0.8444
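The `189` at the start of each progress line is the number of batches per epoch. It follows directly from the 50,000 training images and the `batch_size` of 265 set earlier, as this small calculation shows (the variable names other than `batch_size` are made up for the sketch).

```python
import math

num_train_examples = 50_000  # from x_train.shape
batch_size = 265             # as set in the hyperparameters

# The last batch is allowed to be smaller, so round up.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
last_batch_size = num_train_examples - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch, last_batch_size)
```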
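In this experiment, the model is trained in two phases. In the first phase, the encoder is pretrained to optimize the supervised contrastive loss, described in Prannay Khosla et al.
In the second phase, the classifier is trained using the trained encoder with its weights frozen; only the weights of the fully-connected layers with the softmax are optimized.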
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super().__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)


def add_projection_head(encoder):
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    model = keras.Model(
        inputs=inputs, outputs=outputs, name="cifar-encoder_with_projection-head"
    )
    return model
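To see what the loss above rewards, here is a NumPy sketch of the same computation. It assumes that `tfa.losses.npairs_loss` turns the labels into a "same class" matrix, normalizes each row to sum to 1, and applies softmax cross-entropy against the similarity logits; that reading of the library is an assumption of this sketch, and the function name and toy data are made up.

```python
import numpy as np


def supervised_contrastive_loss(labels, feature_vectors, temperature=1.0):
    """NumPy sketch of the loss above (assumed n-pairs behaviour)."""
    labels = np.asarray(labels).reshape(-1)  # same role as tf.squeeze
    features = np.asarray(feature_vectors, dtype=np.float64)
    # L2-normalize rows so dot products are cosine similarities.
    features = features / np.linalg.norm(features, axis=1, keepdims=True)
    logits = features @ features.T / temperature
    # Target per row: uniform over samples sharing the row's label.
    same_class = (labels[:, None] == labels[None, :]).astype(np.float64)
    targets = same_class / same_class.sum(axis=1, keepdims=True)
    # Numerically stable log-softmax, then cross-entropy averaged over rows.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(targets * log_probs).sum(axis=1).mean())


labels = np.array([[0], [0], [1], [1]])
# Same-class vectors point the same way: low loss.
clustered = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
# Same-class vectors are nearly orthogonal: high loss.
mixed = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]])

clustered_loss = supervised_contrastive_loss(labels, clustered, temperature=0.05)
mixed_loss = supervised_contrastive_loss(labels, mixed, temperature=0.05)
print(clustered_loss, mixed_loss)
```

The loss is small only when images of the same class are mapped to similar directions and images of different classes are not, which is the property the pretrained encoder is meant to have.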
encoder = create_encoder()
encoder_with_projection_head = add_projection_head(encoder)
encoder_with_projection_head.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=SupervisedContrastiveLoss(temperature),
)
encoder_with_projection_head.summary()
history = encoder_with_projection_head.fit(
    x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs
)
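The `temperature` of 0.05 passed to the loss has a large effect on how sharply the softmax separates similar from dissimilar pairs. Cosine similarities lie in [-1, 1], so dividing by 0.05 stretches them to [-20, 20]. The sketch below compares the two settings on a made-up row of similarities; the `softmax` helper and the example values are illustrative only.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    shifted = x - x.max()
    exps = np.exp(shifted)
    return exps / exps.sum()


# Made-up cosine similarities between one anchor and four samples.
cosine_similarities = np.array([1.0, 0.9, 0.2, -0.3])

soft = softmax(cosine_similarities / 1.0)    # temperature = 1
sharp = softmax(cosine_similarities / 0.05)  # temperature = 0.05

print("temperature 1.0 :", np.round(soft, 3))
print("temperature 0.05:", np.round(sharp, 3))
```

With the lower temperature, almost all of the probability mass goes to the most similar samples, so small differences in similarity produce a much stronger training signal.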
classifier = create_classifier(encoder, trainable=False)
history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)
accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
Epoch 1/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3979 - sparse_categorical_accuracy: 0.8869
Epoch 2/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3422 - sparse_categorical_accuracy: 0.8959
Epoch 3/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3251 - sparse_categorical_accuracy: 0.9004
Epoch 4/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3313 - sparse_categorical_accuracy: 0.8963
Epoch 5/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3213 - sparse_categorical_accuracy: 0.9006
Epoch 6/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3221 - sparse_categorical_accuracy: 0.9001
Epoch 7/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3134 - sparse_categorical_accuracy: 0.9001
Epoch 8/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3245 - sparse_categorical_accuracy: 0.8978
Epoch 9/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3144 - sparse_categorical_accuracy: 0.9001
Epoch 10/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3191 - sparse_categorical_accuracy: 0.8984
Epoch 11/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3104 - sparse_categorical_accuracy: 0.9025
Epoch 12/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3261 - sparse_categorical_accuracy: 0.8958
Epoch 13/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3130 - sparse_categorical_accuracy: 0.9001
Epoch 14/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3147 - sparse_categorical_accuracy: 0.9003
Epoch 15/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3113 - sparse_categorical_accuracy: 0.9016
Epoch 16/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3114 - sparse_categorical_accuracy: 0.9008
Epoch 17/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3044 - sparse_categorical_accuracy: 0.9026
Epoch 18/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3142 - sparse_categorical_accuracy: 0.8987
Epoch 19/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3139 - sparse_categorical_accuracy: 0.9018
Epoch 20/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3199 - sparse_categorical_accuracy: 0.8987
Epoch 21/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3125 - sparse_categorical_accuracy: 0.8994
Epoch 22/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3291 - sparse_categorical_accuracy: 0.8967
Epoch 23/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3208 - sparse_categorical_accuracy: 0.8963
Epoch 24/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3065 - sparse_categorical_accuracy: 0.9041
Epoch 25/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3099 - sparse_categorical_accuracy: 0.9006
Epoch 26/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3181 - sparse_categorical_accuracy: 0.8986
Epoch 27/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3112 - sparse_categorical_accuracy: 0.9013
Epoch 28/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3136 - sparse_categorical_accuracy: 0.8996
Epoch 29/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3217 - sparse_categorical_accuracy: 0.8969
Epoch 30/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3161 - sparse_categorical_accuracy: 0.8998
Epoch 31/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3151 - sparse_categorical_accuracy: 0.8999
Epoch 32/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3092 - sparse_categorical_accuracy: 0.9009
Epoch 33/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3246 - sparse_categorical_accuracy: 0.8961
Epoch 34/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3143 - sparse_categorical_accuracy: 0.8995
Epoch 35/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3106 - sparse_categorical_accuracy: 0.9002
Epoch 36/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3210 - sparse_categorical_accuracy: 0.8980
Epoch 37/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3178 - sparse_categorical_accuracy: 0.9009
Epoch 38/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3064 - sparse_categorical_accuracy: 0.9032
Epoch 39/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3196 - sparse_categorical_accuracy: 0.8981
Epoch 40/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3177 - sparse_categorical_accuracy: 0.8988
Epoch 41/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3167 - sparse_categorical_accuracy: 0.8987
Epoch 42/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3110 - sparse_categorical_accuracy: 0.9014
Epoch 43/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3124 - sparse_categorical_accuracy: 0.9002
Epoch 44/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3128 - sparse_categorical_accuracy: 0.8999
Epoch 45/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3131 - sparse_categorical_accuracy: 0.8991
Epoch 46/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3149 - sparse_categorical_accuracy: 0.8992
Epoch 47/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3082 - sparse_categorical_accuracy: 0.9021
Epoch 48/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3223 - sparse_categorical_accuracy: 0.8959
Epoch 49/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3195 - sparse_categorical_accuracy: 0.8981
Epoch 50/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3240 - sparse_categorical_accuracy: 0.8962
313/313 [==============================] - 2s 7ms/step - loss: 0.7332 - sparse_categorical_accuracy: 0.8162
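Test accuracy: 81.62%
We get to an improved test accuracy.
As shown in the experiments, using the supervised contrastive learning technique outperformed the conventional technique in terms of test accuracy. Note that the same training budget (i.e., number of epochs) was given to each technique. Supervised contrastive learning pays off when the encoder involves a complex architecture, like ResNet, and in multi-class problems with many labels. In addition, large batch sizes and multi-layer projection heads improve its effectiveness. See the Supervised Contrastive Learning paper for more details.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.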