作者: Sayak Paul
创建日期: 2021/06/19
最后修改: 2021/06/19
描述: 使用 AdaMatch 统一半监督学习和无监督领域适应。
在此示例中,我们将实现 AdaMatch 算法,该算法由 Berthelot 等人在AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation中提出。它在无监督领域适应的最新技术中树立了一个新标杆(截至 2021 年 6 月)。AdaMatch 特别有趣,因为它将半监督学习(SSL)和无监督领域适应(UDA)统一在一个框架下,从而提供了一种执行半监督领域适应(SSDA)的方法。
此示例需要 TensorFlow 2.5 或更高版本,以及 TensorFlow 模型,可以使用以下命令安装:
!pip install -q tf-models-official==2.9.2
在继续之前,让我们回顾一下此示例的一些初步概念。
在 半监督学习 (SSL) 中,我们使用少量标签数据在更大的未标记数据集上训练模型。用于计算机视觉的流行半监督学习方法包括 FixMatch, MixMatch, Noisy Student Training 等。您可以参考 这个示例 了解标准 SSL 工作流程的样子。
在 无监督领域适应 中,我们可以访问一个源标签数据集和一个目标 未标记 数据集。然后任务是学习一个能够很好泛化到目标数据集的模型。源数据集和目标数据集在分布上存在差异。以下图提供了此思想的示例。在当前示例中,我们使用 MNIST 数据集 作为源数据集,而目标数据集则是 SVHN,该数据集包含房屋号码的图像。两个数据集在纹理、视角、外观等方面存在各种变化因素:它们的领域或分布彼此不同。
深度学习中流行的领域适应算法包括 Deep CORAL, Moment Matching 等。
import tensorflow as tf
tf.random.set_seed(42)
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from keras_cv.layers import RandAugment
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
# MNIST
(
(mnist_x_train, mnist_y_train),
(mnist_x_test, mnist_y_test),
) = keras.datasets.mnist.load_data()
# 添加一个通道维度
mnist_x_train = tf.expand_dims(mnist_x_train, -1)
mnist_x_test = tf.expand_dims(mnist_x_test, -1)
# 将标签转换为独热编码向量
mnist_y_train = tf.one_hot(mnist_y_train, 10).numpy()
# SVHN
svhn_train, svhn_test = tfds.load(
"svhn_cropped", split=["train", "test"], as_supervised=True
)
RESIZE_TO = 32
SOURCE_BATCH_SIZE = 64
TARGET_BATCH_SIZE = 3 * SOURCE_BATCH_SIZE # 参考: 第 3.2 节
EPOCHS = 10
STEPS_PER_EPOCH = len(mnist_x_train) // SOURCE_BATCH_SIZE
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
AUTO = tf.data.AUTOTUNE
LEARNING_RATE = 0.03
WEIGHT_DECAY = 0.0005
INIT = "he_normal"
DEPTH = 28
WIDTH_MULT = 2
SSL 算法的标准元素是将弱增强和强增强版本的同一图像馈送到学习模型,以使其预测一致。对于强增强,RandAugment 是一个标准选择。对于弱增强,我们将使用水平翻转和随机裁剪。
# Initialize `RandAugment` object with 2 layers of
# augmentation transforms and strength of 5.
augmenter = RandAugment(value_range=(0, 255), augmentations_per_image=2, magnitude=0.5)
def weak_augment(image, source=True):
if image.dtype != tf.float32:
image = tf.cast(image, tf.float32)
# MNIST images are grayscale, this is why we first convert them to
# RGB images.
if source:
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
image = tf.image.random_flip_left_right(image)
image = tf.image.random_crop(image, (RESIZE_TO, RESIZE_TO, 3))
return image
def strong_augment(image, source=True):
if image.dtype != tf.float32:
image = tf.cast(image, tf.float32)
if source:
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
image = augmenter(image)
return image
def create_individual_ds(ds, aug_func, source=True):
if source:
batch_size = SOURCE_BATCH_SIZE
else:
# 在训练期间,向模型展示3倍的目标未标记样本
# 在AdaMatch中(论文的第3.2节)。
batch_size = TARGET_BATCH_SIZE
ds = ds.shuffle(batch_size * 10, seed=42)
if source:
ds = ds.map(lambda x, y: (aug_func(x), y), num_parallel_calls=AUTO)
else:
ds = ds.map(lambda x, y: (aug_func(x, False), y), num_parallel_calls=AUTO)
ds = ds.batch(batch_size).prefetch(AUTO)
return ds
_w
和 _s
后缀分别表示弱增强和强增强。
source_ds = tf.data.Dataset.from_tensor_slices((mnist_x_train, mnist_y_train))
source_ds_w = create_individual_ds(source_ds, weak_augment)
source_ds_s = create_individual_ds(source_ds, strong_augment)
final_source_ds = tf.data.Dataset.zip((source_ds_w, source_ds_s))
target_ds_w = create_individual_ds(svhn_train, weak_augment, source=False)
target_ds_s = create_individual_ds(svhn_train, strong_augment, source=False)
final_target_ds = tf.data.Dataset.zip((target_ds_w, target_ds_s))
下面是单个图像批次的样子:
def compute_loss_source(source_labels, logits_source_w, logits_source_s):
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True)
# 首先计算原始源标签与
# 在相同图像的弱增强和强增强版本上生成的预测之间的损失。
w_loss = loss_func(source_labels, logits_source_w)
s_loss = loss_func(source_labels, logits_source_s)
return w_loss + s_loss
def compute_loss_target(target_pseudo_labels_w, logits_target_s, mask):
loss_func = keras.losses.CategoricalCrossentropy(from_logits=True, reduction="none")
target_pseudo_labels_w = tf.stop_gradient(target_pseudo_labels_w)
# 对于目标样本的损失计算,我们将伪标签
# 视为真实标签。这些在反向传播中不被考虑,
# 这是标准的SSL实践。
target_loss = loss_func(target_pseudo_labels_w, logits_target_s)
# 稍后会更多地讨论`mask`。
mask = tf.cast(mask, target_loss.dtype)
target_loss *= mask
return tf.reduce_mean(target_loss, 0)
下图展示了AdaMatch的整体工作流程(摘自 原始论文):
下面是工作流程的简要分步骤说明:
class AdaMatch(keras.Model):
def __init__(self, model, total_steps, tau=0.9):
super().__init__()
self.model = model
self.tau = tau # 表示置信度阈值
self.loss_tracker = tf.keras.metrics.Mean(name="loss")
self.total_steps = total_steps
self.current_step = tf.Variable(0, dtype="int64")
@property
def metrics(self):
return [self.loss_tracker]
# 这是一个预热计划,用于更新
# 目标未标记样本贡献的损失权重。更多
# 内容将在文本中说明。
def compute_mu(self):
pi = tf.constant(np.pi, dtype="float32")
step = tf.cast(self.current_step, dtype="float32")
return 0.5 - tf.cos(tf.math.minimum(pi, (2 * pi * step) / self.total_steps)) / 2
def train_step(self, data):
## 解包和组织数据 ##
source_ds, target_ds = data
(source_w, source_labels), (source_s, _) = source_ds
(
(target_w, _),
(target_s, _),
) = target_ds # 请注意,我们在这里没有使用任何标签。
combined_images = tf.concat([source_w, source_s, target_w, target_s], 0)
combined_source = tf.concat([source_w, source_s], 0)
total_source = tf.shape(combined_source)[0]
total_target = tf.shape(tf.concat([target_w, target_s], 0))[0]
with tf.GradientTape() as tape:
## 前向传递 ##
combined_logits = self.model(combined_images, training=True)
z_d_prime_source = self.model(
combined_source, training=False
) # 不更新 BatchNorm。
z_prime_source = combined_logits[:total_source]
## 1. 对源图像进行随机logit插值 ##
lambd = tf.random.uniform((total_source, 10), 0, 1)
final_source_logits = (lambd * z_prime_source) + (
(1 - lambd) * z_d_prime_source
)
## 2. 分布对齐(仅考虑弱增强的图像) ##
# 计算弱增强源图像的logits的softmax。
y_hat_source_w = tf.nn.softmax(final_source_logits[: tf.shape(source_w)[0]])
# 提取弱增强目标图像的logits并计算softmax。
logits_target = combined_logits[total_source:]
logits_target_w = logits_target[: tf.shape(target_w)[0]]
y_hat_target_w = tf.nn.softmax(logits_target_w)
# 将目标标签分布与源标签分布对齐。
expectation_ratio = tf.reduce_mean(y_hat_source_w) / tf.reduce_mean(
y_hat_target_w
)
y_tilde_target_w = tf.math.l2_normalize(
y_hat_target_w * expectation_ratio, 1
)
## 3. 相对置信度阈值 ##
row_wise_max = tf.reduce_max(y_hat_source_w, axis=-1)
final_sum = tf.reduce_mean(row_wise_max, 0)
c_tau = self.tau * final_sum
mask = tf.reduce_max(y_tilde_target_w, axis=-1) >= c_tau
## 计算损失(注意索引) ##
source_loss = compute_loss_source(
source_labels,
final_source_logits[: tf.shape(source_w)[0]],
final_source_logits[tf.shape(source_w)[0] :],
)
target_loss = compute_loss_target(
y_tilde_target_w, logits_target[tf.shape(target_w)[0] :], mask
)
t = self.compute_mu() # 计算目标损失的权重
total_loss = source_loss + (t * target_loss)
self.current_step.assign_add(
1
) # 更新调度器的当前训练步骤
gradients = tape.gradient(total_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
self.loss_tracker.update_state(total_loss)
return {"loss": self.loss_tracker.result()}
作者在论文中介绍了三项改进:
compute_loss_target()
中使用mask
)。在AdaMatch中,这个阈值相对调整,因此称为相对置信度阈值。有关这些方法的更多详细信息,以及它们各自的贡献,请参考论文。
关于 compute_mu()
:
在AdaMatch中使用的是变化的标量,而不是固定的标量量。它表示目标样本带来的损失权重。在视觉上,权重调度器看起来是这样的:
这个调度器在训练的前一半将目标域损失的权重从0增加到1。然后在训练的后一半将该权重保持为1。
作者在我们这个例子中使用了WideResNet-28-2作为数据集对。以下大部分代码参考自this script。请注意,以下模型内部有一个缩放层,用于将像素值缩放到[0, 1]。
def wide_basic(x, n_input_plane, n_output_plane, stride):
conv_params = [[3, 3, stride, "same"], [3, 3, (1, 1), "same"]]
n_bottleneck_plane = n_output_plane
# 残差块
for i, v in enumerate(conv_params):
if i == 0:
if n_input_plane != n_output_plane:
x = layers.BatchNormalization()(x)
x = layers.Activation("relu")(x)
convs = x
else:
convs = layers.BatchNormalization()(x)
convs = layers.Activation("relu")(convs)
convs = layers.Conv2D(
n_bottleneck_plane,
(v[0], v[1]),
strides=v[2],
padding=v[3],
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(convs)
else:
convs = layers.BatchNormalization()(convs)
convs = layers.Activation("relu")(convs)
convs = layers.Conv2D(
n_bottleneck_plane,
(v[0], v[1]),
strides=v[2],
padding=v[3],
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(convs)
# 快捷连接:单位函数或1x1卷积
# (取决于输入和输出形状之间的差异 - 这
# 取决于我们是否在
# 每个
# 组中使用第一个块;见 `block_series()`).
if n_input_plane != n_output_plane:
shortcut = layers.Conv2D(
n_output_plane,
(1, 1),
strides=stride,
padding="same",
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(x)
else:
shortcut = x
return layers.Add()([convs, shortcut])
# 在同一阶段堆叠残差单元
def block_series(x, n_input_plane, n_output_plane, count, stride):
x = wide_basic(x, n_input_plane, n_output_plane, stride)
for i in range(2, int(count + 1)):
x = wide_basic(x, n_output_plane, n_output_plane, stride=1)
return x
def get_network(image_size=32, num_classes=10):
n = (DEPTH - 4) / 6
n_stages = [16, 16 * WIDTH_MULT, 32 * WIDTH_MULT, 64 * WIDTH_MULT]
inputs = keras.Input(shape=(image_size, image_size, 3))
x = layers.Rescaling(scale=1.0 / 255)(inputs)
conv1 = layers.Conv2D(
n_stages[0],
(3, 3),
strides=1,
padding="same",
kernel_initializer=INIT,
kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
use_bias=False,
)(x)
## 添加宽残差块 ##
conv2 = block_series(
conv1,
n_input_plane=n_stages[0],
n_output_plane=n_stages[1],
count=n,
stride=(1, 1),
) # 阶段 1
conv3 = block_series(
conv2,
n_input_plane=n_stages[1],
n_output_plane=n_stages[2],
count=n,
stride=(2, 2),
) # 阶段 2
conv4 = block_series(
conv3,
n_input_plane=n_stages[2],
n_output_plane=n_stages[3],
count=n,
stride=(2, 2),
) # 阶段 3
batch_norm = layers.BatchNormalization()(conv4)
relu = layers.Activation("relu")(batch_norm)
# 分类器
trunk_outputs = layers.GlobalAveragePooling2D()(relu)
outputs = layers.Dense(
num_classes, kernel_regularizer=regularizers.l2(WEIGHT_DECAY)
)(trunk_outputs)
return keras.Model(inputs, outputs)
我们现在可以像这样实例化一个宽度的 ResNet 模型。请注意,这里使用宽度 ResNet 的目的是使实现尽可能接近原始版本。
wrn_model = get_network()
print(f"模型的参数数量为 {wrn_model.count_params()/1e6} 百万。")
模型的参数数量为 1.471226 百万。
reduce_lr = keras.optimizers.schedules.CosineDecay(LEARNING_RATE, TOTAL_STEPS, 0.25)
optimizer = keras.optimizers.Adam(reduce_lr)
adamatch_trainer = AdaMatch(model=wrn_model, total_steps=TOTAL_STEPS)
adamatch_trainer.compile(optimizer=optimizer)
total_ds = tf.data.Dataset.zip((final_source_ds, final_target_ds))
adamatch_trainer.fit(total_ds, epochs=EPOCHS)
第 1 轮/10
382/382 [==============================] - 155s 392ms/step - loss: 149259583488.0000
第 2 轮/10
382/382 [==============================] - 145s 379ms/step - loss: 2.0935
第 3 轮/10
382/382 [==============================] - 145s 380ms/step - loss: 1.7237
第 4 轮/10
382/382 [==============================] - 142s 370ms/step - loss: 1.9182
第 5 轮/10
382/382 [==============================] - 141s 367ms/step - loss: 2.9698
第 6 轮/10
382/382 [==============================] - 141s 368ms/step - loss: 3.2622
第 7 轮/10
382/382 [==============================] - 141s 367ms/step - loss: 2.9034
第 8 轮/10
382/382 [==============================] - 141s 368ms/step - loss: 3.2735
第 9 轮/10
382/382 [==============================] - 141s 369ms/step - loss: 3.9449
第 10 轮/10
382/382 [==============================] - 141s 369ms/step - loss: 3.5918
<keras.callbacks.History at 0x7f16eb261e20>
# 编译 AdaMatch 模型以计算准确性。
adamatch_trained_model = adamatch_trainer.model
adamatch_trained_model.compile(metrics=keras.metrics.SparseCategoricalAccuracy())
# 在目标测试集上的得分。
svhn_test = svhn_test.batch(TARGET_BATCH_SIZE).prefetch(AUTO)
_, accuracy = adamatch_trained_model.evaluate(svhn_test)
print(f"目标测试集上的准确率: {accuracy * 100:.2f}%")
136/136 [==============================] - 4s 24ms/step - loss: 508.2073 - sparse_categorical_accuracy: 0.2408
目标测试集上的准确率: 24.08%
经过更多的训练,这个分数会提高。当同样的网络使用标准分类目标进行训练时,其准确率为 7.20%,显著低于我们使用 AdaMatch 得到的结果。你可以查看 这个笔记本 以了解更多关于超参数和其他实验细节的信息。
# 源测试集预处理的实用函数。
def prepare_test_ds_source(image, label):
image = tf.image.resize_with_pad(image, RESIZE_TO, RESIZE_TO)
image = tf.tile(image, [1, 1, 3])
return image, label
source_test_ds = tf.data.Dataset.from_tensor_slices((mnist_x_test, mnist_y_test))
source_test_ds = (
source_test_ds.map(prepare_test_ds_source, num_parallel_calls=AUTO)
.batch(TARGET_BATCH_SIZE)
.prefetch(AUTO)
)
# 在源测试集上的评估。
_, accuracy = adamatch_trained_model.evaluate(source_test_ds)
print(f"源测试集上的准确率: {accuracy * 100:.2f}%")
53/53 [==============================] - 2s 24ms/step - loss: 508.2072 - sparse_categorical_accuracy: 0.9736
源测试集上的准确率: 97.36%
你可以通过使用这些 模型权重 来重现结果。