Author: Abhiraam Eranti
Date created: 21/4/11
Last modified: 21/20/12
Description: Keras implementation of Barlow Twins (contrastive self-supervised learning via redundancy reduction).
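Self-supervised learning (SSL) is a relatively novel technique in which a model learns from unlabeled data; it is often used when the data is corrupted or when there is very little of it. A practical use for SSL is to create intermediate embeddings that are learned from the data. These embeddings are based on the dataset itself, with similar images having similar embeddings, and vice versa. They are then attached to the rest of the model, which uses them as information to learn effectively and make correct predictions. Ideally, the embeddings should contain as much information and insight about the data as possible, so that the model can make better predictions.
However, a common problem is that the model creates embeddings that are redundant. For example, if two images are similar, the model may produce a string of 1s, or some other value that carries repeated bits of information. This is no better than a one-hot encoding, or a single bit, as the model's representation; it defeats the purpose of embeddings, because they learn little about the dataset. Other approaches address this by carefully configuring the model so that it tries to avoid redundancy.
Barlow Twins is a new approach to this problem. While other solutions mainly address the first goal, invariance (similar images have similar embeddings), the Barlow Twins method gives equal priority to the goal of reducing redundancy.
It also has the advantage of being much simpler than other methods, and its model architecture is symmetric, meaning that both "twins" in the model perform the same operations. It comes close to the state of the art on ImageNet and even exceeds algorithms such as SimCLR.
One disadvantage of Barlow Twins is that it depends heavily on augmentation; accuracy drops significantly without it.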
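In summary, Barlow Twins creates representations that are both invariant (similar images receive similar embeddings) and low in redundancy.
It is also simpler than other methods.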
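This notebook can train a Barlow Twins model and reach 64% validation accuracy on the CIFAR-10 dataset.
The model takes two versions of the same image (with different augmentations) as input. It then makes a prediction for each of them, creating representations. These representations are then used to build a cross-correlation matrix.
Cross-correlation matrix:
(pred_1.T @ pred_2) / batch_size
The cross-correlation matrix measures the correlation between the output neurons of the two representations that the model produces for the two augmented versions of the data. Ideally, if the two images are the same, the cross-correlation matrix should look like an identity matrix.
When this happens, the representations are invariant (each diagonal entry, which correlates a neuron with its counterpart in the other view, is 1) and show no redundancy (each off-diagonal entry, which correlates two different neurons, is 0).
A good way to understand this is through pseudocode (information from the original paper):
c[i][i] = 1
c[i][j] = 0  (for i != j)
where:
c is the cross-correlation matrix
i is the index of a neuron in one representation
j is the index of a neuron in the second representation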
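The formula above can be tried out on toy data. The following is a minimal NumPy sketch, not the TensorFlow code used later in this example; the helper name `cross_correlation` and the random "embeddings" are assumptions made for illustration. It standardizes each neuron across the batch and shows that feeding the same representation to both sides puts ones on the diagonal.

```python
import numpy as np


def cross_correlation(pred_1: np.ndarray, pred_2: np.ndarray) -> np.ndarray:
    """Cross-correlation matrix of two (batch_size, n_neurons) arrays."""
    batch_size = pred_1.shape[0]
    # Standardize every neuron (column) across the batch.
    z_1 = (pred_1 - pred_1.mean(axis=0)) / pred_1.std(axis=0)
    z_2 = (pred_2 - pred_2.mean(axis=0)) / pred_2.std(axis=0)
    return (z_1.T @ z_2) / batch_size


rng = np.random.default_rng(0)
view = rng.normal(size=(512, 4))  # hypothetical embeddings with 4 neurons
c_same = cross_correlation(view, view)  # both twins produce the same output

# Each diagonal entry correlates a neuron with itself, so it equals 1
# up to floating-point error.
print(np.round(np.diag(c_same), 3))
```

The off-diagonal entries of `c_same` are only approximately zero here, because random columns are uncorrelated in expectation rather than exactly; pushing them toward zero is what the redundancy reduction objective is for.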
Illustration taken from the original paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
Paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
Original implementation: facebookresearch/barlowtwins
!pip install tensorflow-addons
import os

# Small speedup: saves about 30 seconds on the first epoch and 1-2 seconds
# per epoch afterwards, roughly 5 minutes of training time overall.
# Allocates two private threads to the GPU, which lets more operations
# get done faster.
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf  # framework
from tensorflow import keras  # for tf.keras
import tensorflow_addons as tfa  # LAMB optimizer and gaussian_filter2d function
import numpy as np  # np.random.random
import matplotlib.pyplot as plt  # graphs
import datetime  # tensorboard log naming

# XLA optimization for faster performance (saves up to 10-15 minutes in total)
tf.config.optimizer.set_jit(True)
['Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.7/dist-packages (0.15.0)',
'Requirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow-addons) (2.7.1)']
[
(train_features, train_labels),
(test_features, test_labels),
] = keras.datasets.cifar10.load_data()
train_features = train_features / 255.0
test_features = test_features / 255.0
# Batch size of the dataset
BATCH_SIZE = 512
# Width and height of the images
IMAGE_SIZE = 32
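The Barlow Twins algorithm relies heavily on data augmentation. One distinctive feature of the method is that augmentations are sometimes applied probabilistically.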
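The idea of applying an augmentation only some of the time can be sketched in plain Python. This is an illustration under stated assumptions, not the TensorFlow layer used in this example: the helper `maybe_apply` and the explicit `rng` argument are hypothetical names introduced here.

```python
import random


def random_execute(prob: float, rng: random.Random) -> bool:
    """Return True with probability `prob` (0 <= prob <= 1)."""
    return rng.random() < prob


def maybe_apply(value, transform, prob: float, rng: random.Random):
    """Apply `transform` to `value` with probability `prob`."""
    return transform(value) if random_execute(prob, rng) else value


rng = random.Random(0)
trials = 10_000
hits = sum(random_execute(0.2, rng) for _ in range(trials))
print(f"fraction applied: {hits / trials:.3f}")  # close to 0.2
```

Drawing one uniform number per call and comparing it with the probability keeps every augmentation independent of the others, so each image receives its own random combination.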
Augmentation utilities
class Augmentation(keras.layers.Layer):
    """Base augmentation class.
    Base augmentation class. Contains the random_execute method.
    Methods:
        random_execute: method that returns true or false based
          on a probability. Used to determine whether an augmentation
          will be run.
    """

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """random_execute function.
        Arguments:
            prob: a float value from 0-1 that determines the
              probability.
        Returns:
            returns true or false based on the probability.
        """
        return tf.random.uniform([], minval=0, maxval=1) < prob
class RandomToGrayscale(Augmentation):
    """RandomToGrayscale class.
    RandomToGrayscale class. Randomly makes an image
    grayscaled based on the random_execute method. There
    is a 20% chance that an image will be grayscaled.
    Methods:
        call: method that grayscales an image 20% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a grayscaled version of the image 20% of the time
              and the original image 80% of the time.
        """
        if self.random_execute(0.2):
            x = tf.image.rgb_to_grayscale(x)
            x = tf.tile(x, [1, 1, 3])
        return x
class RandomColorJitter(Augmentation):
    """RandomColorJitter class.
    RandomColorJitter class. Randomly adds color jitter to an image.
    Color jitter means to add random brightness, contrast,
    saturation, and hue to an image. There is a 80% chance that an
    image will be randomly color-jittered.
    Methods:
        call: method that color-jitters an image 80% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Adds color jitter to image, including:
          Brightness change by a max-delta of 0.8
          Contrast change by a max-delta of 0.8
          Saturation change by a max-delta of 0.8
          Hue change by a max-delta of 0.2
        Originally, the same deltas of the original paper
        were used, but a performance boost of almost 2% was found
        when doubling them.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a color-jittered version of the image 80% of the time
              and the original image 20% of the time.
        """
        if self.random_execute(0.8):
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.4, 1.6)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x
class RandomFlip(Augmentation):
    """RandomFlip class.
    RandomFlip class. Randomly flips image horizontally. There is a 50%
    chance that an image will be randomly flipped.
    Methods:
        call: method that flips an image 50% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Randomly flips the image.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a flipped version of the image 50% of the time
              and the original image 50% of the time.
        """
        if self.random_execute(0.5):
            # Deterministic flip: random_execute already supplies the 50%
            # chance. Using random_flip_left_right here would flip only
            # about 25% of the images.
            x = tf.image.flip_left_right(x)
        return x
class RandomResizedCrop(Augmentation):
    """RandomResizedCrop class.
    RandomResizedCrop class. Randomly crop an image to a random size,
    then resize the image back to the original size.
    Attributes:
        image_size: The dimension of the image
    Methods:
        call: method that does random resize crop to the image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Does random resize crop by randomly cropping an image to a random
        size 75% - 100% the size of the image. Then resizes it.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a randomly cropped image.
        """
        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )
        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize
class RandomSolarize(Augmentation):
    """RandomSolarize class.
    RandomSolarize class. Randomly solarizes an image.
    Solarization is when pixels accidentally flip to an inverted state.
    Methods:
        call: method that does random solarization 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Randomly solarizes the image.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a solarized version of the image 20% of the time
              and the original image 80% of the time.
        """
        if self.random_execute(0.2):
            # flips abnormally low pixels to abnormally high pixels
            # NOTE: the constants 10 and 255 assume a 0-255 pixel range. The
            # images in this example are scaled to [0, 1], so `x < 10` holds
            # for every pixel and the image passes through unchanged.
            x = tf.where(x < 10, x, 255 - x)
        return x
class RandomBlur(Augmentation):
    """RandomBlur class.
    RandomBlur class. Randomly blurs an image.
    Methods:
        call: method that does random blur 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.
        Randomly blurs the image.
        Arguments:
            x: a tf.Tensor representing the image.
        Returns:
            returns a blurred version of the image 20% of the time
              and the original image 80% of the time.
        """
        if self.random_execute(0.2):
            # NOTE: inside a tf.function this NumPy call runs when the
            # function is traced, so the same sigma is reused afterwards.
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x
class RandomAugmentor(keras.Model):
    """RandomAugmentor class.
    RandomAugmentor class. Chains all the augmentations into
    one pipeline.
    Attributes:
        image_size: An integer representing the width and height
          of the image. Designed to be used for square images.
        random_resized_crop: Instance variable representing the
          RandomResizedCrop layer.
        random_flip: Instance variable representing the
          RandomFlip layer.
        random_color_jitter: Instance variable representing the
          RandomColorJitter layer.
        random_blur: Instance variable representing the
          RandomBlur layer
        random_to_grayscale: Instance variable representing the
          RandomToGrayscale layer
        random_solarize: Instance variable representing the
          RandomSolarize layer
    Methods:
        call: chains layers in pipeline together
    """

    def __init__(self, image_size: int):
        super().__init__()
        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        x = self.random_to_grayscale(x)
        x = self.random_solarize(x)
        x = tf.clip_by_value(x, 0, 1)
        return x


bt_augmentor = RandomAugmentor(IMAGE_SIZE)
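The RandomSolarize layer above compares pixels against 10 and inverts them with `255 - x`, which presumes a 0-255 range, while the images in this example are divided by 255 beforehand. As a hedged illustration of what solarization looks like on [0, 1] data, here is a small NumPy sketch; the function `solarize` and its threshold of 0.5 are assumptions made here and are not part of the original pipeline.

```python
import numpy as np


def solarize(image: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Invert every pixel at or above `threshold`; assumes values in [0, 1]."""
    return np.where(image < threshold, image, 1.0 - image)


pixels = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
print(solarize(pixels).tolist())  # -> [0.0, 0.25, 0.5, 0.25, 0.0]
```

Pixels below the threshold are left alone, and brighter pixels are mirrored around the top of the range, which is the partial inversion that gives solarized images their look.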
A class that creates the Barlow Twins dataset.
The dataset consists of two copies of each image, with each copy receiving different augmentations.
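For the two copies to line up, both halves of the dataset have to visit the images in the same order, which is why the class below shuffles both with one shared seed. The plain-Python sketch here only illustrates that idea; it is an analogy, not tf.data code, and `shuffled_indices` is a hypothetical helper.

```python
import random


def shuffled_indices(n: int, seed: int) -> list:
    """Return the indices 0..n-1 in an order fully determined by `seed`."""
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return order


# Two independent pipelines over the same data, shuffled with the same seed.
order_a = shuffled_indices(8, seed=1024)
order_b = shuffled_indices(8, seed=1024)

# Zipping them pairs every image with itself, so only the augmentations differ.
pairs = list(zip(order_a, order_b))
print(all(a == b for a, b in pairs))
```

With different seeds the zipped pairs would usually contain two unrelated images, and the loss would then be asked to correlate embeddings that should not match.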
class BTDatasetCreator:
    """Barlow Twins dataset creator class.
    BTDatasetCreator class. Responsible for creating the
    Barlow Twins dataset.
    Attributes:
        options: tf.data.Options used to configure settings that
          may improve performance.
        seed: random seed for shuffling. Used to synchronize the two
          augmented versions.
        augmentor: augmentor used for augmentation.
    Methods:
        __call__: creates the Barlow Twins dataset.
        augmented_version: creates one half of the dataset.
    """

    def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
        self.options = tf.data.Options()
        self.options.threading.max_intra_op_parallelism = 1
        self.seed = seed
        self.augmentor = augmentor

    def augmented_version(self, ds: list) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(ds)
            .shuffle(1000, seed=self.seed)
            .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE)
            .with_options(self.options)
        )

    def __call__(self, ds: list) -> tf.data.Dataset:
        a1 = self.augmented_version(ds)
        a2 = self.augmented_version(ds)
        return tf.data.Dataset.zip((a1, a2)).with_options(self.options)


augment_versions = BTDatasetCreator(bt_augmentor)(train_features)
View examples of the dataset.
sample_augment_versions = iter(augment_versions)


def plot_values(batch: tuple):
    fig, axs = plt.subplots(3, 3)
    fig1, axs1 = plt.subplots(3, 3)
    fig.suptitle("Augmentation 1")
    fig1.suptitle("Augmentation 2")
    a1, a2 = batch
    # plot the images on both figures
    for i in range(3):
        for j in range(3):
            # change (add `/ 255`) if the images are not already in [0, 1]
            axs[i][j].imshow(a1[3 * i + j])
            axs[i][j].axis("off")
            axs1[i][j].imshow(a2[3 * i + j])
            axs1[i][j].axis("off")
    plt.show()


plot_values(next(sample_augment_versions))
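The following sections follow the original authors' pseudocode, which contains both the model and the loss function (see the figure below). A reference for the variables used is also included.
Reference:
y_a: first augmented version of the original image.
y_b: second augmented version of the original image.
z_a: model representation (embedding) of y_a.
z_b: model representation (embedding) of y_b.
z_a_norm: normalized z_a.
z_b_norm: normalized z_b.
c: cross-correlation matrix.
c_diff: diagonal portion of the loss (invariance term).
off_diag: off-diagonal portion of the loss (redundancy reduction term).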
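Barlow Twins uses the cross-correlation matrix for its loss. The loss function has two parts: the invariance term, computed from the diagonal of the matrix, and the redundancy reduction term, computed from its off-diagonal entries.
The two parts are then summed.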
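Written as a formula, the loss described above takes the following form. This is a sketch in notation introduced here for illustration: $N$ is the batch size, $b$ indexes samples in the batch, $z^{A}$ and $z^{B}$ stand for the normalized embeddings of the two augmented versions, and $\lambda$ weights the off-diagonal term.

```latex
\mathcal{C}_{ij} = \frac{1}{N} \sum_{b=1}^{N} z^{A}_{b,i} \, z^{B}_{b,j},
\qquad
\mathcal{L} = \sum_{i} \left( \mathcal{C}_{ii} - 1 \right)^{2}
  + \lambda \sum_{i} \sum_{j \neq i} \mathcal{C}_{ij}^{2}
```

The first sum is the invariance term and vanishes when every diagonal entry is 1; the second is the redundancy reduction term and vanishes when every off-diagonal entry is 0.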
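class BarlowLoss(keras.losses.Loss):
    """BarlowLoss class.
    BarlowLoss class. Creates a loss function based on the
    cross-correlation matrix.
    Attributes:
        batch_size: the batch size of the dataset
        lambda_amt: the value for lambda (used in cross_corr_matrix_loss)
    Methods:
        __init__: gets instance variables
        call: gets the loss based on the cross-correlation matrix
        get_off_diag: used in calculating the off-diagonal section
          of the loss function; sets the diagonal to zero.
        cross_corr_matrix_loss: creates the loss based on the
          cross-correlation matrix.
    """

    def __init__(self, batch_size: int):
        """__init__ method.
        Gets the instance variables.
        Arguments:
            batch_size: an integer value representing the batch size of the
              dataset. Used for the cross-correlation matrix calculation.
        """
        super().__init__()
        self.lambda_amt = 5e-3
        self.batch_size = batch_size

    def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
        """get_off_diag method.
        Sets the diagonal of the cross-correlation matrix to zero.
        This is used in the off-diagonal portion of the loss function,
        where we square the off-diagonal values and sum them.
        Arguments:
            c: a tf.Tensor that represents the cross-correlation matrix
        Returns:
            Returns a tf.Tensor that represents the cross-correlation
            matrix with its diagonal set to zero.
        """
        zero_diag = tf.zeros(c.shape[-1])
        return tf.linalg.set_diag(c, zero_diag)

    def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix_loss method.
        Gets the loss based on the cross-correlation matrix.
        We want the diagonal to be 1 and every other value to be zero,
        which shows that the two augmented images are similar.
        Loss function procedure:
        Take the diagonal of the cross-correlation matrix, subtract 1,
        and square the result so that there are no negative values.
        Take the off-diagonal of the matrix (see get_off_diag()),
        square those values to remove negatives and increase them,
        and multiply them by a lambda that weights them relative to the
        diagonal (there are more off-diagonal values than diagonal ones).
        Sum each of the two parts, then add the sums together.
        Arguments:
            c: a tf.Tensor that represents the cross-correlation matrix
        Returns:
            Returns a tf.Tensor that represents the loss.
        """
        # subtract one from the diagonal and square it (first part)
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)
        # take the off-diagonal, square it, multiply by lambda (second part)
        off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt
        # sum the first and second parts together
        loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)
        return loss

    def normalize(self, output: tf.Tensor) -> tf.Tensor:
        """normalize method.
        Normalizes the model prediction.
        Arguments:
            output: the model prediction.
        Returns:
            Returns a normalized version of the model prediction.
        """
        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )

    def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix method.
        Creates a cross-correlation matrix from the predictions.
        It transposes the first prediction and multiplies it with the
        second, creating a matrix with shape (n_dense_units, n_dense_units).
        See build_twin() for more info. It is then divided by the batch size.
        Arguments:
            z_a_norm: a normalized version of the first prediction.
            z_b_norm: a normalized version of the second prediction.
        Returns:
            Returns a cross-correlation matrix.
        """
        return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size

    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
        """call method.
        Computes the cross-correlation loss. Builds the cross-correlation
        matrix with cross_corr_matrix(), then finds the loss and returns it
        (see cross_corr_matrix_loss()).
        Arguments:
            z_a: the prediction for the first set of augmented data.
            z_b: the prediction for the second set of augmented data.
        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """
        z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
        c = self.cross_corr_matrix(z_a_norm, z_b_norm)
        loss = self.cross_corr_matrix_loss(c)
        return loss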
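As a cross-check of the loss class above, the same steps can be traced in NumPy using the variable names from the reference list. This is a minimal sketch under stated assumptions: the function `barlow_loss` and the random embeddings are introduced here for illustration, and the code is not a replacement for the Keras loss.

```python
import numpy as np


def barlow_loss(z_a: np.ndarray, z_b: np.ndarray, lambda_amt: float = 5e-3) -> float:
    batch_size = z_a.shape[0]
    # Normalize each embedding dimension across the batch.
    z_a_norm = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
    z_b_norm = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)
    # Cross-correlation matrix, shape (n_features, n_features).
    c = (z_a_norm.T @ z_b_norm) / batch_size
    # Invariance term: diagonal entries should be 1.
    c_diff = (np.diag(c) - 1.0) ** 2
    # Redundancy reduction term: off-diagonal entries should be 0.
    off_diag = (c - np.diag(np.diag(c))) ** 2 * lambda_amt
    return float(c_diff.sum() + off_diag.sum())


rng = np.random.default_rng(0)
z_a = rng.normal(size=(256, 8))    # hypothetical embeddings of view A
z_rand = rng.normal(size=(256, 8))  # unrelated embeddings

loss_same = barlow_loss(z_a, z_a)      # identical twins: diagonal is 1
loss_rand = barlow_loss(z_a, z_rand)   # unrelated twins: diagonal near 0
print(loss_same, loss_rand)
```

Identical embeddings leave only the small off-diagonal penalty, whereas unrelated embeddings pay roughly one unit of loss per neuron through the invariance term.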
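The model has two parts: an encoder network, which is a ResNet-34, and a projector network, which creates the embeddings.
ResNet encoder network implementation: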
class ResNet34:
    """ResNet34 class.
    Responsible for the ResNet-34 architecture.
    Modified from
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2.
    View their website for more information.
    """

    def identity_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Process the residue with a 1x1 convolution
        x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        # Step 1 (set up the input layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (initial convolution layer along with max pooling)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define the size of the sub-blocks and the initial filter size
        block_layers = [3, 4, 6, 3]
        filter_size = 64
        # Step 3: add the ResNet blocks
        for i in range(4):
            if i == 0:
                # Sub-block 1 needs no residual/convolutional block
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One residual/convolutional block followed by identity blocks
                # The filter size goes on increasing by a factor of 2
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4: end of the network (pooling and flattening)
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
        return model
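To see what size of vector the encoder above hands to the projector, the spatial dimensions can be followed layer by layer. The sketch below is plain Python arithmetic, not Keras code; the helper names are hypothetical, and it assumes that a layer with padding="same" and stride s maps a length n to ceil(n / s), which also matches the stride-2 1x1 convolution on the skip path for the sizes that occur here.

```python
import math


def same_pad_out(size: int, stride: int) -> int:
    """Output length of a layer with padding="same" and the given stride."""
    return math.ceil(size / stride)


def encoder_output_dim(image_size: int = 32) -> int:
    size = image_size + 2 * 3         # ZeroPadding2D((3, 3))
    size = same_pad_out(size, 2)      # 7x7 convolution, stride 2
    size = same_pad_out(size, 2)      # 3x3 max pooling, stride 2
    filters = 64                      # the first stage keeps the resolution
    for _ in range(3):                # the three later stages
        filters *= 2
        size = same_pad_out(size, 2)  # stride-2 convolutional block
    size = same_pad_out(size, 2)      # 2x2 average pooling, stride 2
    return size * size * filters      # Flatten()


print(encoder_output_dim(32))  # -> 512
```

For 32x32 CIFAR-10 images the feature map shrinks to a single spatial position with 512 channels, so the first projector layer receives 512 inputs.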
Projector network:
def build_twin() -> keras.Model:
    """build_twin method.
    Builds a Barlow Twins model consisting of an encoder (ResNet-34)
    and a projector, which generates embeddings for the images.
    Returns:
        returns a Barlow Twins model
    """
    # number of dense neurons in the projector
    n_dense_neurons = 5000
    # encoder network
    resnet = ResNet34()()
    last_layer = resnet.layers[-1].output
    # intermediate layers of the projector network
    n_layers = 2
    for i in range(n_layers):
        dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
        if i == 0:
            x = dense(last_layer)
        else:
            x = dense(x)
        x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)
    x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)
    model = keras.Model(resnet.input, x)
    return model
See the pseudocode for reference.
class BarlowModel(keras.Model):
    """BarlowModel class.
    BarlowModel class. Responsible for making predictions and handling
    gradient descent with the optimizer.
    Attributes:
        model: the Barlow model architecture.
        loss_tracker: the loss metric.
    Methods:
        train_step: one train step; does model predictions, loss, and
          optimizer step.
        metrics: returns the metrics.
    """

    def __init__(self):
        super().__init__()
        self.model = build_twin()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, batch: tf.Tensor) -> dict:
        """train_step method.
        Does one train step. Makes model predictions, computes the loss,
        passes the loss to the optimizer, and has the optimizer apply
        the gradients.
        Arguments:
            batch: one batch of data to be given to the loss function.
        Returns:
            Returns a dictionary with the loss metric.
        """
        # get the two augmented versions from the batch
        y_a, y_b = batch
        with tf.GradientTape() as tape:
            # get the predictions for both versions
            z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
            loss = self.loss(z_a, z_b)
        grads_model = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}


# set up the model, optimizer, and loss
bm = BarlowModel()
# LAMB optimizer chosen because the batch size is large; it converges
# much faster than Adam or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)
bm.compile(optimizer=optimizer, loss=loss)
# Expected training time: 1 hour 30 minutes
history = bm.fit(augment_versions, epochs=160)
plt.plot(history.history["loss"])
plt.show()
Epoch 1/160
97/97 [==============================] - 89s 294ms/step - loss: 3480.7588
Epoch 2/160
97/97 [==============================] - 29s 294ms/step - loss: 2163.4197
Epoch 3/160
97/97 [==============================] - 29s 294ms/step - loss: 1939.0248
Epoch 4/160
97/97 [==============================] - 29s 294ms/step - loss: 1810.4800
Epoch 5/160
97/97 [==============================] - 29s 294ms/step - loss: 1725.7401
Epoch 6/160
97/97 [==============================] - 29s 294ms/step - loss: 1658.2261
Epoch 7/160
97/97 [==============================] - 29s 294ms/step - loss: 1592.0747
Epoch 8/160
97/97 [==============================] - 29s 294ms/step - loss: 1545.2579
Epoch 9/160
97/97 [==============================] - 29s 294ms/step - loss: 1509.6631
Epoch 10/160
97/97 [==============================] - 29s 294ms/step - loss: 1484.1141
Epoch 11/160
97/97 [==============================] - 29s 293ms/step - loss: 1456.8615
Epoch 12/160
97/97 [==============================] - 29s 294ms/step - loss: 1430.0315
Epoch 13/160
97/97 [==============================] - 29s 294ms/step - loss: 1418.1147
Epoch 14/160
97/97 [==============================] - 29s 294ms/step - loss: 1385.7473
Epoch 15/160
97/97 [==============================] - 29s 294ms/step - loss: 1362.8176
Epoch 16/160
97/97 [==============================] - 29s 294ms/step - loss: 1353.6069
Epoch 17/160
97/97 [==============================] - 29s 294ms/step - loss: 1331.3687
Epoch 18/160
97/97 [==============================] - 29s 294ms/step - loss: 1323.1509
Epoch 19/160
97/97 [==============================] - 29s 294ms/step - loss: 1309.3015
Epoch 20/160
97/97 [==============================] - 29s 294ms/step - loss: 1303.2418
Epoch 21/160
97/97 [==============================] - 29s 294ms/step - loss: 1278.0450
Epoch 22/160
97/97 [==============================] - 29s 294ms/step - loss: 1272.2640
Epoch 23/160
97/97 [==============================] - 29s 294ms/step - loss: 1259.4225
Epoch 24/160
97/97 [==============================] - 29s 294ms/step - loss: 1246.8461
Epoch 25/160
97/97 [==============================] - 29s 294ms/step - loss: 1235.0269
Epoch 26/160
97/97 [==============================] - 29s 295ms/step - loss: 1228.4196
Epoch 27/160
97/97 [==============================] - 29s 295ms/step - loss: 1220.0851
Epoch 28/160
97/97 [==============================] - 29s 294ms/step - loss: 1208.5876
Epoch 29/160
97/97 [==============================] - 29s 294ms/step - loss: 1203.1449
Epoch 30/160
97/97 [==============================] - 29s 294ms/step - loss: 1199.5155
Epoch 31/160
97/97 [==============================] - 29s 294ms/step - loss: 1183.9818
Epoch 32/160
97/97 [==============================] - 29s 294ms/step - loss: 1173.9989
Epoch 33/160
97/97 [==============================] - 29s 294ms/step - loss: 1171.3789
Epoch 34/160
97/97 [==============================] - 29s 294ms/step - loss: 1160.8230
Epoch 35/160
97/97 [==============================] - 29s 294ms/step - loss: 1159.4148
Epoch 36/160
97/97 [==============================] - 29s 294ms/step - loss: 1148.4250
Epoch 37/160
97/97 [==============================] - 29s 294ms/step - loss: 1138.1802
Epoch 38/160
97/97 [==============================] - 29s 294ms/step - loss: 1135.9139
Epoch 39/160
97/97 [==============================] - 29s 294ms/step - loss: 1126.8186
Epoch 40/160
97/97 [==============================] - 29s 294ms/step - loss: 1119.6173
Epoch 41/160
97/97 [==============================] - 29s 293ms/step - loss: 1113.9358
Epoch 42/160
97/97 [==============================] - 29s 294ms/step - loss: 1106.0131
Epoch 43/160
97/97 [==============================] - 29s 294ms/step - loss: 1104.7386
Epoch 44/160
97/97 [==============================] - 29s 294ms/step - loss: 1097.7909
Epoch 45/160
97/97 [==============================] - 29s 294ms/step - loss: 1091.4229
Epoch 46/160
97/97 [==============================] - 29s 293ms/step - loss: 1082.3530
Epoch 47/160
97/97 [==============================] - 29s 294ms/step - loss: 1081.9459
Epoch 48/160
97/97 [==============================] - 29s 294ms/step - loss: 1078.5864
Epoch 49/160
97/97 [==============================] - 29s 293ms/step - loss: 1075.9255
Epoch 50/160
97/97 [==============================] - 29s 293ms/step - loss: 1070.9954
Epoch 51/160
97/97 [==============================] - 29s 294ms/step - loss: 1061.1058
Epoch 52/160
97/97 [==============================] - 29s 294ms/step - loss: 1055.0126
Epoch 53/160
97/97 [==============================] - 29s 294ms/step - loss: 1045.7827
Epoch 54/160
97/97 [==============================] - 29s 293ms/step - loss: 1047.5338
Epoch 55/160
97/97 [==============================] - 29s 294ms/step - loss: 1043.9012
Epoch 56/160
97/97 [==============================] - 29s 294ms/step - loss: 1044.5902
Epoch 57/160
97/97 [==============================] - 29s 294ms/step - loss: 1038.3389
Epoch 58/160
97/97 [==============================] - 29s 294ms/step - loss: 1032.1195
Epoch 59/160
97/97 [==============================] - 29s 294ms/step - loss: 1026.5962
Epoch 60/160
97/97 [==============================] - 29s 294ms/step - loss: 1018.2954
Epoch 61/160
97/97 [==============================] - 29s 294ms/step - loss: 1014.7681
Epoch 62/160
97/97 [==============================] - 29s 294ms/step - loss: 1007.7906
Epoch 63/160
97/97 [==============================] - 29s 294ms/step - loss: 1012.9134
Epoch 64/160
97/97 [==============================] - 29s 294ms/step - loss: 1009.7881
Epoch 65/160
97/97 [==============================] - 29s 294ms/step - loss: 1003.2436
Epoch 66/160
97/97 [==============================] - 29s 293ms/step - loss: 997.0688
Epoch 67/160
97/97 [==============================] - 29s 294ms/step - loss: 999.1620
Epoch 68/160
97/97 [==============================] - 29s 294ms/step - loss: 993.2636
Epoch 69/160
97/97 [==============================] - 29s 295ms/step - loss: 988.5142
Epoch 70/160
97/97 [==============================] - 29s 294ms/step - loss: 981.5876
Epoch 71/160
97/97 [==============================] - 29s 294ms/step - loss: 978.3053
Epoch 72/160
97/97 [==============================] - 29s 295ms/step - loss: 978.8599
Epoch 73/160
97/97 [==============================] - 29s 294ms/step - loss: 973.7569
Epoch 74/160
97/97 [==============================] - 29s 294ms/step - loss: 971.2402
Epoch 75/160
97/97 [==============================] - 29s 295ms/step - loss: 964.2864
Epoch 76/160
97/97 [==============================] - 29s 294ms/step - loss: 963.4999
Epoch 77/160
97/97 [==============================] - 29s 294ms/step - loss: 959.7264
Epoch 78/160
97/97 [==============================] - 29s 294ms/step - loss: 958.1680
Epoch 79/160
97/97 [==============================] - 29s 295ms/step - loss: 952.0243
Epoch 80/160
97/97 [==============================] - 29s 295ms/step - loss: 947.8354
Epoch 81/160
97/97 [==============================] - 29s 295ms/step - loss: 945.8139
Epoch 82/160
97/97 [==============================] - 29s 294ms/step - loss: 944.9114
Epoch 83/160
97/97 [==============================] - 29s 294ms/step - loss: 940.7040
Epoch 84/160
97/97 [==============================] - 29s 295ms/step - loss: 942.7839
Epoch 85/160
97/97 [==============================] - 29s 295ms/step - loss: 937.4374
Epoch 86/160
97/97 [==============================] - 29s 295ms/step - loss: 934.6262
Epoch 87/160
97/97 [==============================] - 29s 295ms/step - loss: 929.8491
Epoch 88/160
97/97 [==============================] - 29s 294ms/step - loss: 937.7441
Epoch 89/160
97/97 [==============================] - 29s 295ms/step - loss: 927.0290
Epoch 90/160
97/97 [==============================] - 29s 295ms/step - loss: 925.6105
Epoch 91/160
97/97 [==============================] - 29s 294ms/step - loss: 921.6296
Epoch 92/160
97/97 [==============================] - 29s 294ms/step - loss: 925.8184
Epoch 93/160
97/97 [==============================] - 29s 294ms/step - loss: 912.5261
Epoch 94/160
97/97 [==============================] - 29s 295ms/step - loss: 915.6510
Epoch 95/160
97/97 [==============================] - 29s 295ms/step - loss: 909.5853
Epoch 96/160
97/97 [==============================] - 29s 294ms/step - loss: 911.1563
Epoch 97/160
97/97 [==============================] - 29s 295ms/step - loss: 906.8965
Epoch 98/160
97/97 [==============================] - 29s 294ms/step - loss: 902.3696
Epoch 99/160
97/97 [==============================] - 29s 295ms/step - loss: 899.8710
Epoch 100/160
97/97 [==============================] - 29s 294ms/step - loss: 894.1641
Epoch 101/160
97/97 [==============================] - 29s 294ms/step - loss: 895.7336
Epoch 102/160
97/97 [==============================] - 29s 294ms/step - loss: 900.1674
Epoch 103/160
97/97 [==============================] - 29s 294ms/step - loss: 887.2552
Epoch 104/160
97/97 [==============================] - 29s 295ms/step - loss: 893.1448
Epoch 105/160
97/97 [==============================] - 29s 294ms/step - loss: 889.9379
Epoch 106/160
97/97 [==============================] - 29s 295ms/step - loss: 884.9587
Epoch 107/160
97/97 [==============================] - 29s 294ms/step - loss: 880.9834
Epoch 108/160
97/97 [==============================] - 29s 295ms/step - loss: 883.2829
Epoch 109/160
97/97 [==============================] - 29s 294ms/step - loss: 876.6734
Epoch 110/160
97/97 [==============================] - 29s 294ms/step - loss: 873.4252
Epoch 111/160
97/97 [==============================] - 29s 294ms/step - loss: 873.2639
Epoch 112/160
97/97 [==============================] - 29s 295ms/step - loss: 871.0381
Epoch 113/160
97/97 [==============================] - 29s 294ms/step - loss: 866.5417
Epoch 114/160
97/97 [==============================] - 29s 294ms/step - loss: 862.2125
Epoch 115/160
97/97 [==============================] - 29s 294ms/step - loss: 862.8839
Epoch 116/160
97/97 [==============================] - 29s 294ms/step - loss: 861.1781
Epoch 117/160
97/97 [==============================] - 29s 294ms/step - loss: 856.6186
Epoch 118/160
97/97 [==============================] - 29s 294ms/step - loss: 857.3196
Epoch 119/160
97/97 [==============================] - 29s 294ms/step - loss: 858.0576
Epoch 120/160
97/97 [==============================] - 29s 294ms/step - loss: 855.3264
Epoch 121/160
97/97 [==============================] - 29s 294ms/step - loss: 850.6841
Epoch 122/160
97/97 [==============================] - 29s 294ms/step - loss: 849.6420
Epoch 123/160
97/97 [==============================] - 29s 294ms/step - loss: 846.6933
Epoch 124/160
97/97 [==============================] - 29s 295ms/step - loss: 847.4681
Epoch 125/160
97/97 [==============================] - 29s 294ms/step - loss: 838.5893
Epoch 126/160
97/97 [==============================] - 29s 294ms/step - loss: 841.2516
Epoch 127/160
97/97 [==============================] - 29s 295ms/step - loss: 840.6940
Epoch 128/160
97/97 [==============================] - 29s 294ms/step - loss: 840.9053
Epoch 129/160
97/97 [==============================] - 29s 294ms/step - loss: 836.9998
Epoch 130/160
97/97 [==============================] - 29s 294ms/step - loss: 836.6874
Epoch 131/160
97/97 [==============================] - 29s 294ms/step - loss: 835.2166
Epoch 132/160
97/97 [==============================] - 29s 295ms/step - loss: 833.7071
Epoch 133/160
97/97 [==============================] - 29s 294ms/step - loss: 829.0735
Epoch 134/160
97/97 [==============================] - 29s 294ms/step - loss: 830.1376
Epoch 135/160
97/97 [==============================] - 29s 294ms/step - loss: 827.7781
Epoch 136/160
97/97 [==============================] - 29s 294ms/step - loss: 825.4308
Epoch 137/160
97/97 [==============================] - 29s 294ms/step - loss: 823.2223
Epoch 138/160
97/97 [==============================] - 29s 294ms/step - loss: 821.3982
Epoch 139/160
97/97 [==============================] - 29s 294ms/step - loss: 821.0161
Epoch 140/160
97/97 [==============================] - 29s 294ms/step - loss: 816.7703
Epoch 141/160
97/97 [==============================] - 29s 294ms/step - loss: 814.1747
Epoch 142/160
97/97 [==============================] - 29s 294ms/step - loss: 813.5908
Epoch 143/160
97/97 [==============================] - 29s 294ms/step - loss: 814.3353
Epoch 144/160
97/97 [==============================] - 29s 295ms/step - loss: 807.3126
Epoch 145/160
97/97 [==============================] - 29s 294ms/step - loss: 811.9185
Epoch 146/160
97/97 [==============================] - 29s 294ms/step - loss: 808.0939
Epoch 147/160
97/97 [==============================] - 29s 294ms/step - loss: 806.7361
Epoch 148/160
97/97 [==============================] - 29s 294ms/step - loss: 804.6682
Epoch 149/160
97/97 [==============================] - 29s 294ms/step - loss: 801.5149
Epoch 150/160
97/97 [==============================] - 29s 294ms/step - loss: 803.6600
Epoch 151/160
97/97 [==============================] - 29s 294ms/step - loss: 799.9028
Epoch 152/160
97/97 [==============================] - 29s 294ms/step - loss: 801.5812
Epoch 153/160
97/97 [==============================] - 29s 294ms/step - loss: 791.5322
Epoch 154/160
97/97 [==============================] - 29s 294ms/step - loss: 795.5021
Epoch 155/160
97/97 [==============================] - 29s 294ms/step - loss: 795.7894
Epoch 156/160
97/97 [==============================] - 29s 294ms/step - loss: 794.7897
Epoch 157/160
97/97 [==============================] - 29s 294ms/step - loss: 794.8560
Epoch 158/160
97/97 [==============================] - 29s 294ms/step - loss: 791.5762
Epoch 159/160
97/97 [==============================] - 29s 294ms/step - loss: 784.3605
Epoch 160/160
97/97 [==============================] - 29s 294ms/step - loss: 781.7180
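Linear evaluation: to evaluate the model's performance, we add a linear dense layer at the end and freeze the main model's weights, letting only the dense layer be tuned. If the model has actually learned something, the accuracy will be significantly higher than random chance.
Accuracy on CIFAR-10: this notebook reaches 64%. This is much better than the 10% we would get from random guessing.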
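The linear evaluation idea described above can be shown on toy data without TensorFlow. The NumPy sketch below is an illustration under stated assumptions, not the evaluation used in this notebook: a fixed random projection with tanh stands in for the frozen pretrained network, the two-class data is synthetic, and only the linear head is trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy two-class data: 4-dimensional points centered at +2 or -2.
labels = rng.integers(0, 2, size=200)
inputs = rng.normal(size=(200, 4)) + np.where(labels[:, None] == 1, 2.0, -2.0)

# "Frozen encoder": a fixed random projection followed by tanh. It stands in
# for the pretrained network and is never updated below.
encoder_w = rng.normal(size=(4, 16))
frozen_copy = encoder_w.copy()
features = np.tanh(inputs @ encoder_w)

# Trainable linear head (logistic regression) on top of the frozen features.
head_w = np.zeros(16)
head_b = 0.0
for _ in range(200):
    probs = 1.0 / (1.0 + np.exp(-(features @ head_w + head_b)))
    grad = probs - labels  # gradient of the cross-entropy w.r.t. the logit
    head_w -= 0.1 * features.T @ grad / len(labels)
    head_b -= 0.1 * grad.mean()

accuracy = ((features @ head_w + head_b > 0) == (labels == 1)).mean()
print(f"linear-probe accuracy on the toy data: {accuracy:.2f}")
```

Because the encoder weights never change, any accuracy above chance has to come from information already present in the frozen features, which is what the linear evaluation is meant to measure.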
# Approx: 64% accuracy with this Barlow Twins model.
xy_ds = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
test_ds = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
model = keras.models.Sequential(
    [
        bm.model,
        keras.layers.Dense(
            10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
        ),
    ]
)
model.layers[0].trainable = False
linear_optimizer = tfa.optimizers.LAMB()
model.compile(
    optimizer=linear_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(xy_ds, epochs=35, validation_data=test_ds)