
Barlow Twins for Contrastive Self-Supervised Learning

Author: Abhiraam Eranti
Date created: 2021/11/04
Last modified: 2021/12/20

Description: A Keras implementation of Barlow Twins (contrastive SSL with redundancy reduction).


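Self-supervised learning (SSL) is a relatively novel technique in which a model learns from unlabeled data. It is often used when the data is corrupted or scarce. A practical use of SSL is to create intermediate embeddings that are learned from the data. These embeddings are based on the dataset itself: similar images get similar embeddings, and dissimilar images get dissimilar ones. They are then attached to the rest of the model, which uses them as information to learn effectively and make good predictions. Ideally, the embeddings contain as much information and insight about the data as possible, so that the model can make better predictions.

A common problem, however, is that the model produces redundant embeddings. For example, if two images are similar, the model may output just a string of 1's, or some other value that carries repeated information. This is no better than a one-hot encoding, or a single bit, as the model's representation; it defeats the purpose of embeddings, because they have learned little about the dataset. Other approaches address this by carefully configuring the model so that it tends to avoid redundancy.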

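Barlow Twins is a newer approach to this problem. While other solutions mainly address the first goal, invariance (similar images have similar embeddings), Barlow Twins gives equal priority to the goal of reducing redundancy.

It also has the advantage of being much simpler than other methods, and its model architecture is symmetric, meaning that both "twins" in the model do the same thing. It is close to state of the art on ImageNet, and it even exceeds methods such as SimCLR.

One disadvantage of Barlow Twins is that it depends heavily on augmentations; without them, accuracy drops significantly.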

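TL;DR: Barlow Twins creates representations that are:

  • Invariant.
  • Not redundant, carrying as much information about the dataset as possible.


This notebook trains a Barlow Twins model and reaches up to 64% validation accuracy on the CIFAR-10 dataset.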





(pred_1.T @ pred_2) / batch_size



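  1. Invariance. The diagonal shows the correlation between each representation's neurons and the corresponding neurons of its augmented counterpart. Because the two versions come from the same image, the diagonal of the matrix should show strong correlation between them. If the images were different, there should be no diagonal.
  2. No sign of redundancy. If a neuron correlates with an off-diagonal neuron, it is not correctly identifying the similarity between the two augmented images; that is, it is redundant.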


c[i][i] = 1
c[i][j] = 0

  c is the cross-correlation matrix
  i is the index of a neuron in one representation
  j is the index of a neuron in the second representation

Taken from the original paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
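As a rough illustration of the target matrix, the sketch below builds a cross-correlation matrix for a toy batch in plain Python. The helper names `standardize` and `cross_correlation` and the toy embeddings are assumptions made for this example; they are not part of the notebook. Both views are identical here and the two features are uncorrelated, so the result should be the identity matrix, which is the target described above.

```python
import math


def standardize(batch):
    """Give every feature (column) zero mean and unit std over the batch."""
    n = len(batch)
    cols = list(zip(*batch))
    means = [sum(col) / n for col in cols]
    stds = [
        math.sqrt(sum((v - m) ** 2 for v in col) / n)
        for col, m in zip(cols, means)
    ]
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in batch]


def cross_correlation(z_a, z_b):
    """c[i][j] = sum over the batch of z_a[n][i] * z_b[n][j], over batch size."""
    n = len(z_a)
    dim = len(z_a[0])
    return [
        [sum(z_a[k][i] * z_b[k][j] for k in range(n)) / n for j in range(dim)]
        for i in range(dim)
    ]


# Hypothetical toy embeddings: 4 samples, 2 uncorrelated features.
z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]

# Using the same batch for both "views" mimics a perfectly invariant model.
c = cross_correlation(standardize(z), standardize(z))
print(c)
```

The diagonal entries play the role of c[i][i] and the remaining entries the role of c[i][j].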


Paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction

Original implementation: facebookresearch/barlowtwins


!pip install tensorflow-addons
import os

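# Slightly faster improvement: 30 seconds less on the first epoch and
# 1-2 seconds less per epoch afterwards, saving about 5 minutes of
# training time overall

# Allocates two private threads for the GPU, which allows more operations
# to be done faster
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf  # framework
from tensorflow import keras  # for tf.keras
import tensorflow_addons as tfa  # LAMB optimizer and gaussian_filter2d function
import numpy as np  # np.random.random
import matplotlib.pyplot as plt  # graphs
import datetime  # tensorboard logs naming

# XLA optimization for faster performance (up to 10-15 minutes total time saved)
tf.config.optimizer.set_jit(True)
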
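Load the CIFAR-10 dataset

[
    (train_features, train_labels),
    (test_features, test_labels),
] = keras.datasets.cifar10.load_data()

train_features = train_features / 255.0
test_features = test_features / 255.0


# Batch size of the dataset
# (512 matches the 97 steps per epoch in the training log: 50000 // 512)
BATCH_SIZE = 512
# Width and height of each image (CIFAR-10 images are 32x32)
IMAGE_SIZE = 32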




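  • RandomToGrayscale: converts the image to grayscale with 20% probability
  • RandomColorJitter: applies color jitter with 80% probability
  • RandomFlip: flips the image horizontally with 50% probability
  • RandomResizedCrop: always (100% probability) crops the image to a random size and then resizes it back
  • RandomSolarize: solarizes the image with 20% probability
  • RandomBlur: blurs the image with 20% probability
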
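Each augmentation in the list above is gated by a probability. As a minimal, framework-free sketch of that gating (Python's `random` module stands in for `tf.random.uniform`, and the seed and trial count are arbitrary assumptions), drawing a uniform number and comparing it with the probability fires at roughly the stated rate:

```python
import random


def random_execute(prob: float, rng: random.Random) -> bool:
    """Return True with probability `prob` (uniform draw in [0, 1) below prob)."""
    return rng.random() < prob


rng = random.Random(0)  # fixed seed so the sketch is repeatable
trials = 10_000

# Fraction of trials in which a 20% augmentation would have been applied.
fraction = sum(random_execute(0.2, rng) for _ in range(trials)) / trials
print(f"applied in {fraction:.1%} of trials")
```

The printed fraction should land close to 20%, with some sampling noise.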
class Augmentation(keras.layers.Layer):
    """Base augmentation class.

    Base augmentation class. Contains the random_execute method.

    Methods:
        random_execute: method that returns true or false based
          on a probability. Used to determine whether an augmentation
          will be run.
    """

    def __init__(self):
        super().__init__()

    def random_execute(self, prob: float) -> bool:
        """random_execute function.

        Arguments:
            prob: a float value from 0-1 that determines the
              probability of returning true.

        Returns:
            returns true or false based on the probability.
        """

        return tf.random.uniform([], minval=0, maxval=1) < prob

class RandomToGrayscale(Augmentation):
    """RandomToGrayscale class.

    RandomToGrayscale class. Randomly makes an image
    grayscaled based on the random_execute method. There
    is a 20% chance that an image will be grayscaled.

    Methods:
        call: method that grayscales an image 20% of
          the time.
    """

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a grayscaled version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            x = tf.image.rgb_to_grayscale(x)
            # repeat the single gray channel so the image keeps 3 channels
            x = tf.tile(x, [1, 1, 3])
        return x

class RandomColorJitter(Augmentation):
    """RandomColorJitter class.

    RandomColorJitter class. Randomly adds color jitter to an image.
    Color jitter means to add random brightness, contrast,
    saturation, and hue to an image. There is an 80% chance that an
    image will be randomly color-jittered.

    Methods:
        call: method that color-jitters an image 80% of
          the time.
    """

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Adds color jitter to image, including:
          Brightness change by a max-delta of 0.8
          Contrast change by a factor between 0.4 and 1.6
          Saturation change by a factor between 0.4 and 1.6
          Hue change by a max-delta of 0.2
        Originally, the same deltas of the original paper
        were used, but a performance boost of almost 2% was found
        when doubling them.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a color-jittered version of the image 80% of the time
              and the original image 20% of the time.
        """

        if self.random_execute(0.8):
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.4, 1.6)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x

class RandomFlip(Augmentation):
    """RandomFlip class.

    RandomFlip class. Randomly flips image horizontally. There is a 50%
    chance that an image will be randomly flipped.

    Methods:
        call: method that flips an image 50% of
          the time.
    """

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly flips the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a flipped version of the image 50% of the time
              and the original image 50% of the time.
        """

        if self.random_execute(0.5):
            x = tf.image.random_flip_left_right(x)
        return x

class RandomResizedCrop(Augmentation):
    """RandomResizedCrop class.

    RandomResizedCrop class. Randomly crop an image to a random size,
    then resize the image back to the original size.

    Attributes:
        image_size: The dimension of the image

    Methods:
        call: method that does random resize crop to the image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Does random resize crop by randomly cropping an image to a random
        size 75% - 100% the size of the image. Then resizes it.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a randomly cropped image.
        """

        # scalar integer crop size; random_crop needs an integer shape
        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )

        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize

class RandomSolarize(Augmentation):
    """RandomSolarize class.

    RandomSolarize class. Randomly solarizes an image.
    Solarization is when pixels accidentally flip to an inverted state.

    Methods:
        call: method that does random solarization 20% of the time.
    """

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly solarizes the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a solarized version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # keeps pixels below 10 and inverts the rest
            # NOTE: these thresholds assume pixel values in [0, 255]; for
            # inputs already scaled to [0, 1], x < 10 is always true and
            # x is returned unchanged
            x = tf.where(x < 10, x, 255 - x)
        return x

class RandomBlur(Augmentation):
    """RandomBlur class.

    RandomBlur class. Randomly blurs an image.

    Methods:
        call: method that does random blur 20% of the time.
    """

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly blurs the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a blurred version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # random blur strength (sigma) in [0, 1)
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x

class RandomAugmentor(keras.Model):
    """RandomAugmentor class.

    RandomAugmentor class. Chains all the augmentations into
    one pipeline.

    Attributes:
        image_size: An integer representing the width and height
          of the image. Designed to be used for square images.
        random_resized_crop: Instance variable representing the
          RandomResizedCrop layer.
        random_flip: Instance variable representing the
          RandomFlip layer.
        random_color_jitter: Instance variable representing the
          RandomColorJitter layer.
        random_blur: Instance variable representing the
          RandomBlur layer
        random_to_grayscale: Instance variable representing the
          RandomToGrayscale layer
        random_solarize: Instance variable representing the
          RandomSolarize layer

    Methods:
        call: chains layers in pipeline together
    """

    def __init__(self, image_size: int):
        super().__init__()

        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        x = self.random_to_grayscale(x)
        x = self.random_solarize(x)

        # keep pixel values in the valid [0, 1] range after jittering
        x = tf.clip_by_value(x, 0, 1)
        return x

bt_augmentor = RandomAugmentor(IMAGE_SIZE)



The dataset consists of two copies of each image, with each copy receiving different augmentations.
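The two copies only form valid pairs if both halves of the dataset are shuffled in the same order. The sketch below illustrates that idea in plain Python rather than `tf.data`; the helper `shuffled_view`, the string "images" and the tagging "augmentations" are hypothetical stand-ins chosen for this example.

```python
import random


def shuffled_view(items, seed, augment):
    """Shuffle indices with a fixed seed, then apply an augmentation to each item."""
    order = list(range(len(items)))
    random.Random(seed).shuffle(order)
    return [augment(items[i]) for i in order]


images = ["img0", "img1", "img2", "img3", "img4"]

# Same seed for both views, different (stand-in) augmentations.
view_a = shuffled_view(images, seed=1024, augment=lambda x: (x, "aug_a"))
view_b = shuffled_view(images, seed=1024, augment=lambda x: (x, "aug_b"))

# Zipping the two views pairs up two versions of the same source image.
pairs = list(zip(view_a, view_b))
for a, b in pairs:
    print(a, b)
```

With different seeds the two orders would generally differ, and a pair would mix two unrelated images.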

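class BTDatasetCreator:
    """Barlow Twins dataset creator class.

    BTDatasetCreator class. Responsible for creating the
    Barlow Twins dataset.

    Attributes:
        options: tf.data.Options needed to configure a setting
          that may improve performance.
        seed: random seed for shuffling. Used to synchronize the two
          augmented versions.
        augmentor: augmentor used for augmentation.

    Methods:
        __call__: creates the Barlow dataset.
        augmented_version: creates one half of the dataset.
    """

    def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
        self.options = tf.data.Options()
        self.options.threading.max_intra_op_parallelism = 1
        self.seed = seed
        self.augmentor = augmentor

    def augmented_version(self, ds: list) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(ds)
            .shuffle(1000, seed=self.seed)
            .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE, drop_remainder=True)
        )

    def __call__(self, ds: list) -> tf.data.Dataset:
        # both halves use the same shuffle seed, so images stay paired
        a1 = self.augmented_version(ds)
        a2 = self.augmented_version(ds)

        return tf.data.Dataset.zip((a1, a2)).with_options(self.options)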

augment_versions = BTDatasetCreator(bt_augmentor)(train_features)


sample_augment_versions = iter(augment_versions)


def plot_values(batch: tuple):
    fig, axs = plt.subplots(3, 3)
    fig1, axs1 = plt.subplots(3, 3)

    fig.suptitle("Augmentation 1")
    fig1.suptitle("Augmentation 2")

    a1, a2 = batch

    # plots images on both tables
    for i in range(3):
        for j in range(3):
            # add `/ 255` here if the images are not already scaled to [0, 1]
            axs[i][j].imshow(a1[3 * i + j])
            axs1[i][j].imshow(a2[3 * i + j])

    plt.show()


plot_values(next(sample_augment_versions))






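The following section follows the original author's pseudocode, which covers both the model and the loss function (see the figure below). It also includes a reference for the variables used.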



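y_a: first augmented version of the original image.
y_b: second augmented version of the original image.
z_a: model representation (embedding) of y_a.
z_b: model representation (embedding) of y_b.
z_a_norm: normalized z_a.
z_b_norm: normalized z_b.
c: cross-correlation matrix.
c_diff: diagonal portion of the loss (invariance term).
off_diag: off-diagonal portion of the loss (redundancy reduction term).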



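  • The invariance term (diagonal). This part pushes the diagonal of the matrix toward 1. When that happens, the matrix indicates that the images are correlated (the same).
    • The loss function subtracts 1 from the diagonal and squares the values.
  • The redundancy reduction term (off-diagonal). Here, the Barlow Twins loss function aims to make these values zero. As mentioned before, a representation neuron that correlates with values off the diagonal is redundant.
    • The off-diagonal values are squared.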
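The two terms above can be sketched on small hand-written matrices. This is a plain-Python illustration, not the notebook's TensorFlow loss: the function name `barlow_loss` and the three example matrices are assumptions for this sketch, while the weight 5e-3 is the `lambda_amt` value used in the notebook's loss class.

```python
def barlow_loss(c, lambda_amt=5e-3):
    """Invariance term plus lambda-weighted redundancy term for a square matrix c."""
    dim = len(c)
    # Diagonal: subtract 1 and square.
    invariance = sum((c[i][i] - 1.0) ** 2 for i in range(dim))
    # Off-diagonal: square every entry with i != j.
    redundancy = sum(
        c[i][j] ** 2 for i in range(dim) for j in range(dim) if i != j
    )
    return invariance + lambda_amt * redundancy


perfect = [[1.0, 0.0], [0.0, 1.0]]    # invariant and non-redundant
redundant = [[1.0, 1.0], [1.0, 1.0]]  # invariant, but neurons duplicate each other
weak = [[0.5, 0.0], [0.0, 0.5]]       # non-redundant, but only weakly invariant

for name, matrix in [("perfect", perfect), ("redundant", redundant), ("weak", weak)]:
    print(name, barlow_loss(matrix))
```

Only the identity matrix gives a loss of zero; the small lambda means a redundant entry is penalized far less than a missing diagonal correlation of the same size.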


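class BarlowLoss(keras.losses.Loss):
    """BarlowLoss class.

    BarlowLoss class. Creates a loss function based on the cross-correlation
    matrix.

    Attributes:
        batch_size: the batch size of the dataset
        lambda_amt: the value for lambda (used in cross_corr_matrix_loss)

    Methods:
        __init__: gets instance variables
        call: gets the loss based on the cross-correlation matrix
        get_off_diag: used in calculating the off-diagonal section
          of the loss function; sets the diagonal to zeros.
        cross_corr_matrix_loss: creates loss based on the cross-correlation
          matrix.
    """

    def __init__(self, batch_size: int):
        """__init__ method.

        Gets the instance variables.

        Arguments:
            batch_size: An integer value representing the batch size of the
              dataset. Used for cross-correlation matrix calculation.
        """

        super().__init__()
        self.lambda_amt = 5e-3
        self.batch_size = batch_size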

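    def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
        """get_off_diag method.

        Sets the diagonal of the cross-correlation matrix to zeros, so that
        only the off-diagonal values remain.

        Arguments:
            c: A tf.tensor that represents the cross-correlation matrix

        Returns:
            Returns a tf.tensor that represents the cross-correlation
            matrix with its diagonal set to zeros.
        """

        zero_diag = tf.zeros(c.shape[-1])
        return tf.linalg.set_diag(c, zero_diag)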

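    def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix_loss method.

        We want the diagonal to be 1's and everything else to be zeros, to
        show that the two augmented images are similar.

        Take the diagonal of the cross-correlation matrix, subtract 1, and
        square that value so that there are no negatives.

        Take the off-diagonal of the cc-matrix (see get_off_diag()), square
        it, and multiply it by a lambda that weighs it against the diagonal
        term (there are more off-diagonal values than diagonal values).

        Arguments:
            c: A tf.tensor that represents the cross-correlation matrix

        Returns:
            Returns a (rank-0) tf.tensor that represents the loss.
        """

        # subtracts one from the diagonal and squares it (first part)
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)

        # takes the off-diagonal, squares it, multiplies by lambda (second part)
        off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt

        # sums the first and second parts together
        loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)

        return loss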

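    def normalize(self, output: tf.Tensor) -> tf.Tensor:
        """normalize method.

        Normalizes the model prediction along the batch dimension.

        Arguments:
            output: the model prediction.

        Returns:
            Returns a normalized version of the model prediction.
        """

        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )
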
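    def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix method.

        Creates a cross-correlation matrix from the predictions.
        It transposes the first prediction and multiplies it by the second,
        creating a matrix of shape (n_dense_units, n_dense_units).
        See build_twin() for more info. It is then divided by the batch size.

        Arguments:
            z_a_norm: a normalized version of the first prediction.
            z_b_norm: a normalized version of the second prediction.

        Returns:
            Returns a tf.tensor that represents the cross-correlation matrix.
        """

        return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size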

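    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
        """call method.

        Computes the cross-correlation loss. Normalizes both predictions,
        builds the cross-correlation matrix (see cross_corr_matrix()), then
        finds the loss and returns it (see cross_corr_matrix_loss()).

        Arguments:
            z_a: the prediction for the first set of augmented data.
            z_b: the prediction for the second set of augmented data.

        Returns:
            Returns a (rank-0) tf.Tensor that represents the loss.
        """

        z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
        c = self.cross_corr_matrix(z_a_norm, z_b_norm)
        loss = self.cross_corr_matrix_loss(c)
        return loss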

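Barlow Twins' model architecture


  • The encoder network, which is a resnet-34.
  • The projector network, which creates the model embeddings.
    • It is an MLP with three dense layers; the first two are each followed by batch normalization and ReLU.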
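To see how large the vector passed from the encoder to the projector is, the arithmetic below walks a 32x32 CIFAR-10 image through the strides used by the ResNet34 class in this example. It is a sketch under stated assumptions: every downsampling layer behaves like `padding="same"` (output size is the input size divided by the stride, rounded up), and the helper name `same_padding_out` is made up for this illustration.

```python
import math


def same_padding_out(size: int, stride: int) -> int:
    """Spatial output size of a conv/pool layer with padding="same"."""
    return math.ceil(size / stride)


size = 32 + 2 * 3                  # ZeroPadding2D((3, 3)): 32 -> 38
size = same_padding_out(size, 2)   # 7x7 conv, stride 2: 38 -> 19
size = same_padding_out(size, 2)   # 3x3 max pool, stride 2: 19 -> 10

filters = 64                       # stage 1 keeps 64 filters and the 10x10 size
for _ in range(3):                 # stages 2-4 each start with a stride-2 block
    filters *= 2                   # 64 -> 128 -> 256 -> 512
    size = same_padding_out(size, 2)   # 10 -> 5 -> 3 -> 2

size = same_padding_out(size, 2)   # 2x2 average pool, stride 2: 2 -> 1
encoder_features = size * size * filters
print(encoder_features)
```

Under these assumptions the flattened encoder output has 512 features, which the projector then expands to its much wider dense layers.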


class ResNet34:
    """ResNet34 class.

    Responsible for the ResNet-34 architecture.
    """

    def identity_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filter):
        # copy tensor to variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Process the residue with a 1x1 conv so its shape matches x
        x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        # Step 1 (set up the input layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (initial conv layer along with maxpool)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define the size of the sub-blocks and the initial filter size
        block_layers = [3, 4, 6, 3]
        filter_size = 64
        # Step 3: add the ResNet blocks
        for i in range(4):
            if i == 0:
                # For sub-block 1, a residual/convolutional block is not needed
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One residual/convolutional block followed by identity blocks
                # The filter size keeps increasing by a factor of 2
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4: end of the dense network
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
        return model


def build_twin() -> keras.Model:
    """build_twin method.

    Builds a Barlow Twins model consisting of an encoder (resnet-34)
    and a projector.

    Returns:
        returns a Barlow Twins model
    """

    # number of dense neurons in the projector
    n_dense_neurons = 5000

    # encoder network
    resnet = ResNet34()()
    last_layer = resnet.layers[-1].output

    # intermediate layers of the projector network
    n_layers = 2
    for i in range(n_layers):
        dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
        if i == 0:
            x = dense(last_layer)
        else:
            x = dense(x)
        x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)

    # final projector layer: dense only, no batchnorm or relu
    x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)

    model = keras.Model(resnet.input, x)
    return model



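class BarlowModel(keras.Model):
    """BarlowModel class.

    BarlowModel class. Responsible for making predictions and handling
    the gradient updates with the optimizer.

    Attributes:
        model: the barlow model architecture.
        loss_tracker: the loss metric.

    Methods:
        train_step: one train step; makes model predictions, computes the
          loss, and applies the optimizer step.
        metrics: returns the metrics.
    """

    def __init__(self):
        super().__init__()
        self.model = build_twin()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]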

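    def train_step(self, batch: tf.Tensor) -> tf.Tensor:
        """train_step method.

        Does one train step: makes model predictions, finds the loss, and
        lets the optimizer apply the gradients.

        Arguments:
            batch: one batch of data to be given to the loss function.

        Returns:
            Returns a dictionary with the loss metric.
        """

        # get the two augmentations from the batch
        y_a, y_b = batch

        with tf.GradientTape() as tape:
            # get two versions of predictions
            z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
            loss = self.loss(z_a, z_b)

        grads_model = tape.gradient(loss, self.model.trainable_variables)

        self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
        # record the loss so the tracker reports a running mean
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}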


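  • The LAMB optimizer is used instead of ADAM or SGD.
  • Like the LARS optimizer used in the paper, it lets the model converge much faster than other methods.
  • Expected training time: 1 hour 30 minutes. Go and eat a snack or take a nap.

# sets up model, optimizer, loss

bm = BarlowModel()
# LAMB was chosen because of the large batch size; it converged much
# faster than ADAM or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)

bm.compile(optimizer=optimizer, loss=loss)

# Expected training time: 1 hour 30 min

history = bm.fit(augment_versions, epochs=160)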
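
Linear evaluation: to evaluate the model's performance, we add a linear dense layer at the end and freeze the main model's weights, so that only the dense layer is tuned. If the model has actually learned something, the accuracy will be significantly higher than random chance.

Accuracy on CIFAR-10: this notebook reaches about 64% (the run logged below ends at about 68% validation accuracy). This is far better than the 10% we would get from random guessing.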

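# Approx: 64% accuracy with this barlow twins model.

xy_ds = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .batch(BATCH_SIZE, drop_remainder=True)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .batch(BATCH_SIZE, drop_remainder=True)
)

# frozen Barlow Twins model followed by a single trainable dense layer
model = keras.models.Sequential(
    [
        bm.model,
        keras.layers.Dense(
            10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
        ),
    ]
)

model.layers[0].trainable = False

linear_optimizer = tfa.optimizers.LAMB()
# CIFAR-10 labels are integers, so the sparse cross-entropy loss is used
model.compile(
    optimizer=linear_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(xy_ds, epochs=35, validation_data=test_ds)
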
Epoch 1/35
97/97 [==============================] - 12s 84ms/step - loss: 2.9447 - accuracy: 0.2090 - val_loss: 2.3056 - val_accuracy: 0.3741
Epoch 2/35
97/97 [==============================] - 6s 62ms/step - loss: 1.9912 - accuracy: 0.4867 - val_loss: 1.6910 - val_accuracy: 0.5883
Epoch 3/35
97/97 [==============================] - 6s 62ms/step - loss: 1.5476 - accuracy: 0.6278 - val_loss: 1.4605 - val_accuracy: 0.6465
Epoch 4/35
97/97 [==============================] - 6s 62ms/step - loss: 1.3775 - accuracy: 0.6647 - val_loss: 1.3689 - val_accuracy: 0.6644
Epoch 5/35
97/97 [==============================] - 6s 62ms/step - loss: 1.3027 - accuracy: 0.6769 - val_loss: 1.3232 - val_accuracy: 0.6684
Epoch 6/35
97/97 [==============================] - 6s 62ms/step - loss: 1.2574 - accuracy: 0.6820 - val_loss: 1.2905 - val_accuracy: 0.6717
Epoch 7/35
97/97 [==============================] - 6s 63ms/step - loss: 1.2244 - accuracy: 0.6852 - val_loss: 1.2654 - val_accuracy: 0.6742
Epoch 8/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1979 - accuracy: 0.6868 - val_loss: 1.2460 - val_accuracy: 0.6747
Epoch 9/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1754 - accuracy: 0.6884 - val_loss: 1.2247 - val_accuracy: 0.6773
Epoch 10/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1559 - accuracy: 0.6896 - val_loss: 1.2090 - val_accuracy: 0.6770
Epoch 11/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1380 - accuracy: 0.6907 - val_loss: 1.1904 - val_accuracy: 0.6785
Epoch 12/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1223 - accuracy: 0.6915 - val_loss: 1.1796 - val_accuracy: 0.6776
Epoch 13/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1079 - accuracy: 0.6923 - val_loss: 1.1696 - val_accuracy: 0.6785
Epoch 14/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0954 - accuracy: 0.6931 - val_loss: 1.1564 - val_accuracy: 0.6795
Epoch 15/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0841 - accuracy: 0.6939 - val_loss: 1.1454 - val_accuracy: 0.6807
Epoch 16/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0733 - accuracy: 0.6945 - val_loss: 1.1356 - val_accuracy: 0.6810
Epoch 17/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0634 - accuracy: 0.6948 - val_loss: 1.1313 - val_accuracy: 0.6799
Epoch 18/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0535 - accuracy: 0.6957 - val_loss: 1.1208 - val_accuracy: 0.6808
Epoch 19/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0447 - accuracy: 0.6965 - val_loss: 1.1128 - val_accuracy: 0.6813
Epoch 20/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0366 - accuracy: 0.6968 - val_loss: 1.1082 - val_accuracy: 0.6799
Epoch 21/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0295 - accuracy: 0.6968 - val_loss: 1.0971 - val_accuracy: 0.6821
Epoch 22/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0226 - accuracy: 0.6971 - val_loss: 1.0946 - val_accuracy: 0.6799
Epoch 23/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0166 - accuracy: 0.6977 - val_loss: 1.0916 - val_accuracy: 0.6802
Epoch 24/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0103 - accuracy: 0.6980 - val_loss: 1.0823 - val_accuracy: 0.6819
Epoch 25/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0052 - accuracy: 0.6981 - val_loss: 1.0795 - val_accuracy: 0.6804
Epoch 26/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0001 - accuracy: 0.6984 - val_loss: 1.0759 - val_accuracy: 0.6806
Epoch 27/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9947 - accuracy: 0.6992 - val_loss: 1.0699 - val_accuracy: 0.6809
Epoch 28/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9901 - accuracy: 0.6987 - val_loss: 1.0637 - val_accuracy: 0.6821
Epoch 29/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9862 - accuracy: 0.6991 - val_loss: 1.0603 - val_accuracy: 0.6826
Epoch 30/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9817 - accuracy: 0.6994 - val_loss: 1.0582 - val_accuracy: 0.6813
Epoch 31/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9784 - accuracy: 0.6994 - val_loss: 1.0531 - val_accuracy: 0.6826
Epoch 32/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9743 - accuracy: 0.6998 - val_loss: 1.0505 - val_accuracy: 0.6822
Epoch 33/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9711 - accuracy: 0.6996 - val_loss: 1.0506 - val_accuracy: 0.6800
Epoch 34/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9686 - accuracy: 0.6993 - val_loss: 1.0423 - val_accuracy: 0.6828
Epoch 35/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9653 - accuracy: 0.6999 - val_loss: 1.0429 - val_accuracy: 0.6821


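  • Barlow Twins is a simple and concise method for contrastive and self-supervised learning.
  • With this resnet-34 model architecture, we were able to reach 62-64% validation accuracy.

Use cases of Barlow Twins (and contrastive learning in general)

  • Semi-supervised learning: as shown above, this model reached 62-64% accuracy even though the encoder was never trained with labels. It can be used when you have little labeled data but a lot of unlabeled data.
  • You can run Barlow Twins training on the unlabeled data first, and then do secondary training with the labeled data.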
