
Barlow Twins for Contrastive SSL

Author: Abhiraam Eranti
Date created: 2021/04/11
Last modified: 2021/12/20


Description: A Keras implementation of Barlow Twins (contrastive self-supervised learning with redundancy reduction).


Introduction

Self-supervised learning (SSL) is a relatively novel technique in which a model learns from unlabeled data, and is often used when the data is corrupted or when there is very little of it. A practical use for SSL is to create intermediate embeddings that are learned from the data. These embeddings are based on the dataset itself, with similar images having similar embeddings, and vice versa. They are then attached to the rest of the model, which uses those embeddings as information and effectively learns to make correct predictions. Ideally, these embeddings should contain as much information and insight about the data as possible, so that the model can make better predictions. However, a common problem is that the model creates redundant embeddings. For example, if two images are similar, the model might output a string of 1's, or some other value that contains repeating bits of information. This is no better than a one-hot encoding, or than having just a single bit as the model's representation; it defeats the purpose of the embeddings, since they learn very little about the dataset. Other approaches solve this problem by carefully configuring the model so that it tries not to be redundant.

Barlow Twins is a new approach to this problem; while other solutions mainly tackle the first goal of invariance (similar images having similar embeddings), the Barlow Twins method gives equal priority to the second goal of reducing redundancy.

It also has the advantage of being much simpler than other methods, and its model architecture is symmetric, meaning that the two "twins" in the model perform the same operations. It is also near state-of-the-art on ImageNet, even exceeding algorithms such as SimCLR.

One disadvantage of Barlow Twins is that it relies heavily on augmentation; accuracy drops significantly without it.

To summarize, Barlow Twins creates representations that are:

  • Invariant.
  • Not redundant, and carry as much information about the dataset as possible.

Also, it is simpler than other methods.

This notebook trains a Barlow Twins model and reaches up to 64% validation accuracy on the CIFAR-10 dataset.


High-Level Theory

The model takes two versions of the same image (with different augmentations) as input. Then it makes a prediction for each of them, creating two representations. Those representations are then used to make a cross-correlation matrix.

Cross-correlation matrix:

(pred_1.T @ pred_2) / batch_size

The cross-correlation matrix measures the correlation between the output neurons of the two representations that the model produces for the two augmented versions of the data. Ideally, if the two images are the same, the cross-correlation matrix should look like an identity matrix.

When this happens, it means that the representations:

  1. Are invariant. The diagonal shows the correlation between each neuron of one representation and the corresponding neuron of the other augmented version. Because the two versions come from the same image, the diagonal of the matrix should show a strong correlation between them. If the images were different, there would be no such diagonal.
  2. Aren't redundant. If a neuron shows correlation with a non-diagonal neuron, it means it is not correctly identifying similarities between the two augmented images; that makes it redundant.

Here is a good way to understand this, in pseudocode (information from the original paper):

c[i][i] = 1
c[i][j] = 0

where:
  c is the cross-correlation matrix
  i is the index of a neuron in one representation
  j is the index of a neuron in the second representation

Taken from the original paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction
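
To make this concrete, here is a minimal NumPy sketch (illustrative, not part of the original notebook) showing that two identical, feature-normalized representations produce a cross-correlation matrix close to the identity:

import numpy as np

batch_size, n_features = 256, 8
z_a = np.random.randn(batch_size, n_features)
z_b = z_a.copy()  # pretend both augmented views map to the same representation

# normalize each feature across the batch (mean 0, std 1), as the loss does later
z_a = (z_a - z_a.mean(axis=0)) / z_a.std(axis=0)
z_b = (z_b - z_b.mean(axis=0)) / z_b.std(axis=0)

c = (z_a.T @ z_b) / batch_size  # cross-correlation matrix
print(np.round(c, 2))  # ~1.0 on the diagonal, ~0.0 elsewhere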

References

Paper: Barlow Twins: Self-Supervised Learning via Redundancy Reduction

Original implementation: facebookresearch/barlowtwins


Setup

!pip install tensorflow-addons
import os

# Slight performance improvement: saves 30 seconds on the first epoch
# and 1-2 seconds per epoch afterwards, roughly 5 minutes of total training time

# Give the GPU its own private threads so more operations can be completed faster
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

import tensorflow as tf  # framework
from tensorflow import keras  # for tf.keras
import tensorflow_addons as tfa  # LAMB optimizer and gaussian_filter2d function
import numpy as np  # np.random.random
import matplotlib.pyplot as plt  # graphs
import datetime  # for naming TensorBoard logs

# XLA optimization for faster performance (saves up to 10-15 minutes of total time)
tf.config.optimizer.set_jit(True)

Load the CIFAR-10 dataset

[
    (train_features, train_labels),
    (test_features, test_labels),
] = keras.datasets.cifar10.load_data()

train_features = train_features / 255.0
test_features = test_features / 255.0

Necessary Hyperparameters

# Batch size for the dataset
BATCH_SIZE = 512
# Width and height of the images
IMAGE_SIZE = 32

Augmentation Utilities

The Barlow Twins algorithm relies heavily on augmentation. One unique feature of the method is that some of the augmentations are applied only probabilistically.

Augmentations

  • RandomToGrayscale: randomly applies grayscale to the image 20% of the time
  • RandomColorJitter: randomly applies color jitter 80% of the time
  • RandomFlip: randomly flips the image horizontally 50% of the time
  • RandomResizedCrop: randomly crops the image to a random size, then resizes it back, 100% of the time
  • RandomSolarize: randomly applies solarization to the image 20% of the time
  • RandomBlur: randomly blurs the image 20% of the time

class Augmentation(keras.layers.Layer):
    """Base augmentation class.

    Base augmentation class. Contains the random_execute method.

    Methods:
        random_execute: method that returns true or false based
          on a probability. Used to determine whether an augmentation
          will be run.
    """

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """random_execute function.

        Arguments:
            prob: a float value from 0-1 that determines the
              probability.

        Returns:
            returns true or false based on the probability.
        """

        return tf.random.uniform([], minval=0, maxval=1) < prob


class RandomToGrayscale(Augmentation):
    """RandomToGrayscale class.

    RandomToGrayscale class. Randomly makes an image
    grayscaled based on the random_execute method. There
    is a 20% chance that an image will be grayscaled.

    Methods:
        call: method that grayscales an image 20% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a grayscaled version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            x = tf.image.rgb_to_grayscale(x)
            x = tf.tile(x, [1, 1, 3])
        return x


class RandomColorJitter(Augmentation):
    """RandomColorJitter class.

    RandomColorJitter class. Randomly adds color jitter to an image.
    Color jitter means to add random brightness, contrast,
    saturation, and hue to an image. There is a 80% chance that an
    image will be randomly color-jittered.

    Methods:
        call: method that color-jitters an image 80% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Adds color jitter to image, including:
          Brightness change by a max-delta of 0.8
          Contrast change by a max-delta of 0.8
          Saturation change by a max-delta of 0.8
          Hue change by a max-delta of 0.2
        Originally, the same deltas of the original paper
        were used, but a performance boost of almost 2% was found
        when doubling them.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a color-jittered version of the image 80% of the time
              and the original image 20% of the time.
        """

        if self.random_execute(0.8):
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.4, 1.6)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x


class RandomFlip(Augmentation):
    """RandomFlip class.

    RandomFlip class. Randomly flips image horizontally. There is a 50%
    chance that an image will be randomly flipped.

    Methods:
        call: method that flips an image 50% of
          the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly flips the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a flipped version of the image 50% of the time
              and the original image 50% of the time.
        """

        if self.random_execute(0.5):
            x = tf.image.random_flip_left_right(x)
        return x


class RandomResizedCrop(Augmentation):
    """RandomResizedCrop class.

    RandomResizedCrop class. Randomly crop an image to a random size,
    then resize the image back to the original size.

    Attributes:
        image_size: The dimension of the image

    Methods:
        __call__: method that does random resize crop to the image.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Does random resize crop by randomly cropping an image to a random
        size 75% - 100% the size of the image. Then resizes it.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a randomly cropped image.
        """

        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )

        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize


class RandomSolarize(Augmentation):
    """RandomSolarize class.

    RandomSolarize class. Randomly solarizes an image.
    Solarization is when pixels accidentally flip to an inverted state.

    Methods:
        call: method that does random solarization 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly solarizes the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a solarized version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # invert pixels above a low threshold (solarization); the images
            # were scaled to [0, 1] earlier, so the usual 0-255 threshold and
            # inversion are rescaled to that range
            x = tf.where(x < 10 / 255, x, 1 - x)
        return x


class RandomBlur(Augmentation):
    """RandomBlur class.

    RandomBlur class. Randomly blurs an image.

    Methods:
        call: method that does random blur 20% of the time.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        """call function.

        Randomly blurs the image.

        Arguments:
            x: a tf.Tensor representing the image.

        Returns:
            returns a blurred version of the image 20% of the time
              and the original image 80% of the time.
        """

        if self.random_execute(0.2):
            # note: np.random.random() runs when the tf.function is traced,
            # so sigma stays fixed for each traced call rather than being
            # re-sampled per image
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x


class RandomAugmentor(keras.Model):
    """RandomAugmentor class.

    RandomAugmentor class. Chains all the augmentations into
    one pipeline.

    Attributes:
        image_size: An integer representing the width and height
          of the image. Designed to be used for square images.
        random_resized_crop: Instance variable representing the
          RandomResizedCrop layer.
        random_flip: Instance variable representing the
          RandomFlip layer.
        random_color_jitter: Instance variable representing the
          RandomColorJitter layer.
        random_blur: Instance variable representing the
          RandomBlur layer
        random_to_grayscale: Instance variable representing the
          RandomToGrayscale layer
        random_solarize: Instance variable representing the
          RandomSolarize layer

    Methods:
        call: chains layers in pipeline together
    """

    def __init__(self, image_size: int):
        super().__init__()

        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        x = self.random_to_grayscale(x)
        x = self.random_solarize(x)

        x = tf.clip_by_value(x, 0, 1)
        return x


bt_augmentor = RandomAugmentor(IMAGE_SIZE)

Data Loading

A class that creates the Barlow Twins' dataset.

The dataset consists of two copies of each image, with each copy receiving a different augmentation.

class BTDatasetCreator:
    """Barlow双胞胎数据集创建类。

    BTDatasetCreator类。负责创建
    Barlow双胞胎的数据集。

    属性:
        options: tf.data.Options,用于配置可能提高性能的设置。
        seed: 随机种子,用于洗牌。用于同步两个
          增强版本。
        augmentor: 用于增强的增强器。

    方法:
        __call__: 创建Barlow数据集。
        augmented_version: 创建数据集的一半。
    """

    def __init__(self, augmentor: RandomAugmentor, seed: int = 1024):
        self.options = tf.data.Options()
        self.options.threading.max_intra_op_parallelism = 1
        self.seed = seed
        self.augmentor = augmentor

    def augmented_version(self, ds: list) -> tf.data.Dataset:
        return (
            tf.data.Dataset.from_tensor_slices(ds)
            .shuffle(1000, seed=self.seed)
            .map(self.augmentor, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(BATCH_SIZE, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE)
            .with_options(self.options)
        )

    def __call__(self, ds: list) -> tf.data.Dataset:
        a1 = self.augmented_version(ds)
        a2 = self.augmented_version(ds)

        return tf.data.Dataset.zip((a1, a2)).with_options(self.options)


augment_versions = BTDatasetCreator(bt_augmentor)(train_features)

Let's take a look at a sample of the dataset.

sample_augment_versions = iter(augment_versions)


def plot_values(batch: tuple):
    fig, axs = plt.subplots(3, 3)
    fig1, axs1 = plt.subplots(3, 3)

    fig.suptitle("增强 1")
    fig1.suptitle("增强 2")

    a1, a2 = batch

    # plot images on the two grids
    for i in range(3):
        for j in range(3):
            # note: the images are already scaled to [0, 1], so no / 255 here
            axs[i][j].imshow(a1[3 * i + j])
            axs[i][j].axis("off")
            axs1[i][j].imshow(a2[3 * i + j])
            axs1[i][j].axis("off")

    plt.show()


plot_values(next(sample_augment_versions))



Pseudocode of the loss and model

The following sections follow the original author's pseudocode, which covers both the model and the loss function (see the diagram below). They also include a reference for the variables used.

(pseudocode diagram from the original paper)

Reference:

y_a: the first augmented version of the original image.
y_b: the second augmented version of the original image.
z_a: the model representation (embedding) of y_a.
z_b: the model representation (embedding) of y_b.
z_a_norm: the normalized version of z_a.
z_b_norm: the normalized version of z_b.
c: the cross-correlation matrix.
c_diff: the diagonal portion of the loss (the invariance term).
off_diag: the off-diagonal portion of the loss (the redundancy reduction term).

BarlowLoss: the loss function for the Barlow Twins model

Barlow Twins uses the cross-correlation matrix for its loss. There are two parts to the loss function:

  • The invariance term (diagonal). This part is used to drive the diagonal of the matrix toward 1. When this happens, the matrix shows that the images are correlated (the same).
    • The loss function subtracts 1 from the diagonal and squares the values.
  • The redundancy reduction term (off-diagonal). Here, the Barlow Twins loss function aims to drive these values to zero. As mentioned before, a representation neuron is redundant if it correlates with values that are not on the diagonal.
    • The off-diagonal values are squared.

These two parts are then summed together.
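
In equation form (using the pseudocode notation above; lambda is the weighting term from the paper):

loss = sum_i (1 - c[i][i])**2 + lambda * sum_i sum_{j != i} c[i][j]**2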

class BarlowLoss(keras.losses.Loss):
    """BarlowLoss 类。

    BarlowLoss 类。基于交叉相关矩阵创建损失函数。

    属性:
        batch_size: 数据集的批次大小
        lambda_amt: lambda 的值(用于 cross_corr_matrix_loss)

    方法:
        __init__: 获取实例变量
        call: 根据交叉相关矩阵获取损失
          make_diag_zeros: 用于计算损失函数的非对角部分;将对角线设为零
        cross_corr_matrix_loss: 基于交叉相关矩阵创建损失。
    """

    def __init__(self, batch_size: int):
        """__init__ 方法。

        获取实例变量

        参数:
            batch_size: 一个整数值,表示数据集的批次大小。用于交叉相关矩阵计算。
        """

        super().__init__()
        self.lambda_amt = 5e-3
        self.batch_size = batch_size

    def get_off_diag(self, c: tf.Tensor) -> tf.Tensor:
        """get_off_diag 方法。

        将交叉相关矩阵的对角线设为零。
        这用于损失函数的非对角部分,我们对非对角值取平方并求和。

        参数:
            c: 表示交叉相关矩阵的 tf.tensor

        返回:
            返回一个 tf.tensor,表示对角线为零的交叉相关矩阵。
        """

        zero_diag = tf.zeros(c.shape[-1])
        return tf.linalg.set_diag(c, zero_diag)

    def cross_corr_matrix_loss(self, c: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix_loss 方法。

        根据交叉相关矩阵获取损失。
        我们希望对角线为 1,其他所有值为零,以显示这两幅增强图像相似。

        损失函数过程:
        取交叉相关矩阵的对角线,减去 1,然后平方该值以避免负值。

        取 cc 矩阵的非对角部分(见 get_off_diag()),
        将这些值平方以消除负值并增加值,
        并乘以一个 lambda 值,使其与对角线的值相等(非对角线的值比对角线值多)

        将第一部分和第二部分相加,然后再相加。

        参数:
            c: 表示交叉相关矩阵的 tf.tensor

        返回:
            返回一个 tf.tensor,表示对角线为零的交叉相关矩阵。
        """

        # subtract 1 from the diagonal and square it (first part)
        c_diff = tf.pow(tf.linalg.diag_part(c) - 1, 2)

        # take the off-diagonal, square it, and multiply by lambda (second part)
        off_diag = tf.pow(self.get_off_diag(c), 2) * self.lambda_amt

        # sum the first and second parts together
        loss = tf.reduce_sum(c_diff) + tf.reduce_sum(off_diag)

        return loss

    def normalize(self, output: tf.Tensor) -> tf.Tensor:
        """normalize 方法。

        归一化模型预测。

        参数:
            output: 模型预测。

        返回:
            返回归一化后的模型预测。
        """

        return (output - tf.reduce_mean(output, axis=0)) / tf.math.reduce_std(
            output, axis=0
        )

    def cross_corr_matrix(self, z_a_norm: tf.Tensor, z_b_norm: tf.Tensor) -> tf.Tensor:
        """cross_corr_matrix 方法。

        从预测中创建交叉相关矩阵。
        它转置第一个预测并与第二个预测相乘,创建形状为 (n_dense_units, n_dense_units) 的矩阵。
        有关更多信息,请参见 build_twin()。然后将其除以批次大小。

        参数:
            z_a_norm: 第一个预测的归一化版本。
            z_b_norm: 第二个预测的归一化版本。

        返回:
            返回一个交叉相关矩阵。
        """
        return (tf.transpose(z_a_norm) @ z_b_norm) / self.batch_size

    def call(self, z_a: tf.Tensor, z_b: tf.Tensor) -> tf.Tensor:
        """call 方法。

        计算交叉相关损失。使用 CreateCrossCorr 类生成交叉相关矩阵,然后找到损失并返回(见 cross_corr_matrix_loss())。

        参数:
            z_a: 第一组增强数据的预测。
            z_b: 第二组增强数据的预测。

        返回:
            返回一个(秩为 0 的)tf.Tensor,表示损失。
        """

        z_a_norm, z_b_norm = self.normalize(z_a), self.normalize(z_b)
        c = self.cross_corr_matrix(z_a_norm, z_b_norm)
        loss = self.cross_corr_matrix_loss(c)
        return loss

Barlow Twins' Model Architecture

The model has two parts:

  • The encoder network, which is a resnet-34.
  • The projector network, which creates the model's embeddings.
    • This consists of an MLP with three dense layers, where the first two are each followed by batchnorm and relu.

The resnet encoder network implementation:

class ResNet34:
    """Resnet34类。

        负责Resnet 34架构。
    修改自
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2。
    https://www.analyticsvidhya.com/blog/2021/08/how-to-code-your-resnet-from-scratch-in-tensorflow/#h2_2。
        更多信息请查看他们的网站。
    """

    def identity_block(self, x, filter):
        # copy the tensor to a variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def convolutional_block(self, x, filter):
        # copy the tensor to a variable called x_skip
        x_skip = x
        # Layer 1
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same", strides=(2, 2))(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        x = tf.keras.layers.Activation("relu")(x)
        # Layer 2
        x = tf.keras.layers.Conv2D(filter, (3, 3), padding="same")(x)
        x = tf.keras.layers.BatchNormalization(axis=3)(x)
        # Process residue with conv(1,1)
        x_skip = tf.keras.layers.Conv2D(filter, (1, 1), strides=(2, 2))(x_skip)
        # Add residue
        x = tf.keras.layers.Add()([x, x_skip])
        x = tf.keras.layers.Activation("relu")(x)
        return x

    def __call__(self, shape=(32, 32, 3)):
        # Step 1 (set up the input layer)
        x_input = tf.keras.layers.Input(shape)
        x = tf.keras.layers.ZeroPadding2D((3, 3))(x_input)
        # Step 2 (initial conv layer along with max pooling)
        x = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same")(x)
        # Define the sizes of the sub-blocks and the initial filter size
        block_layers = [3, 4, 6, 3]
        filter_size = 64
        # Step 3: add the Resnet blocks
        for i in range(4):
            if i == 0:
                # Sub-block 1 needs no residual/convolutional block
                for j in range(block_layers[i]):
                    x = self.identity_block(x, filter_size)
            else:
                # One residual/convolutional block followed by identity blocks
                # The filter size increases by a factor of 2
                filter_size = filter_size * 2
                x = self.convolutional_block(x, filter_size)
                for j in range(block_layers[i] - 1):
                    x = self.identity_block(x, filter_size)
        # Step 4: end with pooling and flattening
        x = tf.keras.layers.AveragePooling2D((2, 2), padding="same")(x)
        x = tf.keras.layers.Flatten()(x)
        model = tf.keras.models.Model(inputs=x_input, outputs=x, name="ResNet34")
        return model

The projector network:

def build_twin() -> keras.Model:
    """build_twin方法。

    构建一个包含编码器(resnet-34)和投影器的barlow twins模型,
    用于为图像生成嵌入

    返回:
        返回一个barlow twins模型
    """

    # number of dense neurons in the projector
    n_dense_neurons = 5000

    # encoder network
    resnet = ResNet34()()
    last_layer = resnet.layers[-1].output

    # intermediate layers of the projector network
    n_layers = 2
    for i in range(n_layers):
        dense = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{i}")
        if i == 0:
            x = dense(last_layer)
        else:
            x = dense(x)
        x = tf.keras.layers.BatchNormalization(name=f"projector_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"projector_relu_{i}")(x)

    x = tf.keras.layers.Dense(n_dense_neurons, name=f"projector_dense_{n_layers}")(x)

    model = keras.Model(resnet.input, x)
    return model

Training Loop Model

See the pseudocode for reference.

class BarlowModel(keras.Model):
    """BarlowModel 类。

    BarlowModel 类。负责进行预测和处理
    优化器的梯度下降。

    属性:
        model: barlow 模型架构。
        loss_tracker: 损失指标。

    方法:
        train_step: 一个训练步骤;进行模型预测、计算损失和
            优化器步骤。
        metrics: 返回指标。
    """

    def __init__(self):
        super().__init__()
        self.model = build_twin()
        self.loss_tracker = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        return [self.loss_tracker]

    def train_step(self, batch: tf.Tensor) -> tf.Tensor:
        """train_step 方法。

        进行一个训练步骤。进行模型预测,计算损失,将损失传递给
        优化器,并使优化器应用梯度。

        参数:
            batch: 一批数据,用于损失函数。

        返回:
            返回一个包含损失指标的字典。
        """

        # get the two augmentations from the batch
        y_a, y_b = batch

        with tf.GradientTape() as tape:
            # get the predictions for both versions
            z_a, z_b = self.model(y_a, training=True), self.model(y_b, training=True)
            loss = self.loss(z_a, z_b)

        grads_model = tape.gradient(loss, self.model.trainable_variables)

        self.optimizer.apply_gradients(zip(grads_model, self.model.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

Model Training

  • Used the LAMB optimizer instead of ADAM or SGD.
  • LAMB is similar to the LARS optimizer used in the paper, and lets the model converge faster than other methods would.
  • Expected training time: 1 hour 30 minutes. Go grab a snack or take a nap or something.
# setting up the model, optimizer, and loss

bm = BarlowModel()
# the LAMB optimizer was chosen due to the large batch size;
# it converged much faster than ADAM or SGD
optimizer = tfa.optimizers.LAMB()
loss = BarlowLoss(BATCH_SIZE)

bm.compile(optimizer=optimizer, loss=loss)

# Expected training time: 1 hour 30 min

history = bm.fit(augment_versions, epochs=160)
plt.plot(history.history["loss"])
plt.show()
Epoch 1/160
97/97 [==============================] - 89s 294ms/step - loss: 3480.7588
Epoch 2/160
97/97 [==============================] - 29s 294ms/step - loss: 2163.4197
Epoch 3/160
97/97 [==============================] - 29s 294ms/step - loss: 1939.0248
Epoch 4/160
97/97 [==============================] - 29s 294ms/step - loss: 1810.4800
Epoch 5/160
97/97 [==============================] - 29s 294ms/step - loss: 1725.7401
Epoch 6/160
97/97 [==============================] - 29s 294ms/step - loss: 1658.2261
Epoch 7/160
97/97 [==============================] - 29s 294ms/step - loss: 1592.0747
Epoch 8/160
97/97 [==============================] - 29s 294ms/step - loss: 1545.2579
Epoch 9/160
97/97 [==============================] - 29s 294ms/step - loss: 1509.6631
Epoch 10/160
97/97 [==============================] - 29s 294ms/step - loss: 1484.1141
Epoch 11/160
97/97 [==============================] - 29s 293ms/step - loss: 1456.8615
Epoch 12/160
97/97 [==============================] - 29s 294ms/step - loss: 1430.0315
Epoch 13/160
97/97 [==============================] - 29s 294ms/step - loss: 1418.1147
Epoch 14/160
97/97 [==============================] - 29s 294ms/step - loss: 1385.7473
Epoch 15/160
97/97 [==============================] - 29s 294ms/step - loss: 1362.8176
Epoch 16/160
97/97 [==============================] - 29s 294ms/step - loss: 1353.6069
Epoch 17/160
97/97 [==============================] - 29s 294ms/step - loss: 1331.3687
Epoch 18/160
97/97 [==============================] - 29s 294ms/step - loss: 1323.1509
Epoch 19/160
97/97 [==============================] - 29s 294ms/step - loss: 1309.3015
Epoch 20/160
97/97 [==============================] - 29s 294ms/step - loss: 1303.2418
Epoch 21/160
97/97 [==============================] - 29s 294ms/step - loss: 1278.0450
Epoch 22/160
97/97 [==============================] - 29s 294ms/step - loss: 1272.2640
Epoch 23/160
97/97 [==============================] - 29s 294ms/step - loss: 1259.4225
Epoch 24/160
97/97 [==============================] - 29s 294ms/step - loss: 1246.8461
Epoch 25/160
97/97 [==============================] - 29s 294ms/step - loss: 1235.0269
Epoch 26/160
97/97 [==============================] - 29s 295ms/step - loss: 1228.4196
Epoch 27/160
97/97 [==============================] - 29s 295ms/step - loss: 1220.0851
Epoch 28/160
97/97 [==============================] - 29s 294ms/step - loss: 1208.5876
Epoch 29/160
97/97 [==============================] - 29s 294ms/step - loss: 1203.1449
Epoch 30/160
97/97 [==============================] - 29s 294ms/step - loss: 1199.5155
Epoch 31/160
97/97 [==============================] - 29s 294ms/step - loss: 1183.9818
Epoch 32/160
97/97 [==============================] - 29s 294ms/step - loss: 1173.9989
Epoch 33/160
97/97 [==============================] - 29s 294ms/step - loss: 1171.3789
Epoch 34/160
97/97 [==============================] - 29s 294ms/step - loss: 1160.8230
Epoch 35/160
97/97 [==============================] - 29s 294ms/step - loss: 1159.4148
Epoch 36/160
97/97 [==============================] - 29s 294ms/step - loss: 1148.4250
Epoch 37/160
97/97 [==============================] - 29s 294ms/step - loss: 1138.1802
Epoch 38/160
97/97 [==============================] - 29s 294ms/step - loss: 1135.9139
Epoch 39/160
97/97 [==============================] - 29s 294ms/step - loss: 1126.8186
Epoch 40/160
97/97 [==============================] - 29s 294ms/step - loss: 1119.6173
Epoch 41/160
97/97 [==============================] - 29s 293ms/step - loss: 1113.9358
Epoch 42/160
97/97 [==============================] - 29s 294ms/step - loss: 1106.0131
Epoch 43/160
97/97 [==============================] - 29s 294ms/step - loss: 1104.7386
Epoch 44/160
97/97 [==============================] - 29s 294ms/step - loss: 1097.7909
Epoch 45/160
97/97 [==============================] - 29s 294ms/step - loss: 1091.4229
Epoch 46/160
97/97 [==============================] - 29s 293ms/step - loss: 1082.3530
Epoch 47/160
97/97 [==============================] - 29s 294ms/step - loss: 1081.9459
Epoch 48/160
97/97 [==============================] - 29s 294ms/step - loss: 1078.5864
Epoch 49/160
97/97 [==============================] - 29s 293ms/step - loss: 1075.9255
Epoch 50/160
97/97 [==============================] - 29s 293ms/step - loss: 1070.9954
Epoch 51/160
97/97 [==============================] - 29s 294ms/step - loss: 1061.1058
Epoch 52/160
97/97 [==============================] - 29s 294ms/step - loss: 1055.0126
Epoch 53/160
97/97 [==============================] - 29s 294ms/step - loss: 1045.7827
Epoch 54/160
97/97 [==============================] - 29s 293ms/step - loss: 1047.5338
Epoch 55/160
97/97 [==============================] - 29s 294ms/step - loss: 1043.9012
Epoch 56/160
97/97 [==============================] - 29s 294ms/step - loss: 1044.5902
Epoch 57/160
97/97 [==============================] - 29s 294ms/step - loss: 1038.3389
Epoch 58/160
97/97 [==============================] - 29s 294ms/step - loss: 1032.1195
Epoch 59/160
97/97 [==============================] - 29s 294ms/step - loss: 1026.5962
Epoch 60/160
97/97 [==============================] - 29s 294ms/step - loss: 1018.2954
Epoch 61/160
97/97 [==============================] - 29s 294ms/step - loss: 1014.7681
Epoch 62/160
97/97 [==============================] - 29s 294ms/step - loss: 1007.7906
Epoch 63/160
97/97 [==============================] - 29s 294ms/step - loss: 1012.9134
Epoch 64/160
97/97 [==============================] - 29s 294ms/step - loss: 1009.7881
Epoch 65/160
97/97 [==============================] - 29s 294ms/step - loss: 1003.2436
Epoch 66/160
97/97 [==============================] - 29s 293ms/step - loss: 997.0688
Epoch 67/160
97/97 [==============================] - 29s 294ms/step - loss: 999.1620
Epoch 68/160
97/97 [==============================] - 29s 294ms/step - loss: 993.2636
Epoch 69/160
97/97 [==============================] - 29s 295ms/step - loss: 988.5142
Epoch 70/160
97/97 [==============================] - 29s 294ms/step - loss: 981.5876
Epoch 71/160
97/97 [==============================] - 29s 294ms/step - loss: 978.3053
Epoch 72/160
97/97 [==============================] - 29s 295ms/step - loss: 978.8599
Epoch 73/160
97/97 [==============================] - 29s 294ms/step - loss: 973.7569
Epoch 74/160
97/97 [==============================] - 29s 294ms/step - loss: 971.2402
Epoch 75/160
97/97 [==============================] - 29s 295ms/step - loss: 964.2864
Epoch 76/160
97/97 [==============================] - 29s 294ms/step - loss: 963.4999
Epoch 77/160
97/97 [==============================] - 29s 294ms/step - loss: 959.7264
Epoch 78/160
97/97 [==============================] - 29s 294ms/step - loss: 958.1680
Epoch 79/160
97/97 [==============================] - 29s 295ms/step - loss: 952.0243
Epoch 80/160
97/97 [==============================] - 29s 295ms/step - loss: 947.8354
Epoch 81/160
97/97 [==============================] - 29s 295ms/step - loss: 945.8139
Epoch 82/160
97/97 [==============================] - 29s 294ms/step - loss: 944.9114
Epoch 83/160
97/97 [==============================] - 29s 294ms/step - loss: 940.7040
Epoch 84/160
97/97 [==============================] - 29s 295ms/step - loss: 942.7839
Epoch 85/160
97/97 [==============================] - 29s 295ms/step - loss: 937.4374
Epoch 86/160
97/97 [==============================] - 29s 295ms/step - loss: 934.6262
Epoch 87/160
97/97 [==============================] - 29s 295ms/step - loss: 929.8491
Epoch 88/160
97/97 [==============================] - 29s 294ms/step - loss: 937.7441
Epoch 89/160
97/97 [==============================] - 29s 295ms/step - loss: 927.0290
Epoch 90/160
97/97 [==============================] - 29s 295ms/step - loss: 925.6105
Epoch 91/160
97/97 [==============================] - 29s 294ms/step - loss: 921.6296
Epoch 92/160
97/97 [==============================] - 29s 294ms/step - loss: 925.8184
Epoch 93/160
97/97 [==============================] - 29s 294ms/step - loss: 912.5261
Epoch 94/160
97/97 [==============================] - 29s 295ms/step - loss: 915.6510
Epoch 95/160
97/97 [==============================] - 29s 295ms/step - loss: 909.5853
Epoch 96/160
97/97 [==============================] - 29s 294ms/step - loss: 911.1563
Epoch 97/160
97/97 [==============================] - 29s 295ms/step - loss: 906.8965
Epoch 98/160
97/97 [==============================] - 29s 294ms/step - loss: 902.3696
Epoch 99/160
97/97 [==============================] - 29s 295ms/step - loss: 899.8710
Epoch 100/160
97/97 [==============================] - 29s 294ms/step - loss: 894.1641
Epoch 101/160
97/97 [==============================] - 29s 294ms/step - loss: 895.7336
Epoch 102/160
97/97 [==============================] - 29s 294ms/step - loss: 900.1674
Epoch 103/160
97/97 [==============================] - 29s 294ms/step - loss: 887.2552
Epoch 104/160
97/97 [==============================] - 29s 295ms/step - loss: 893.1448
Epoch 105/160
97/97 [==============================] - 29s 294ms/step - loss: 889.9379
Epoch 106/160
97/97 [==============================] - 29s 295ms/step - loss: 884.9587
Epoch 107/160
97/97 [==============================] - 29s 294ms/step - loss: 880.9834
Epoch 108/160
97/97 [==============================] - 29s 295ms/step - loss: 883.2829
Epoch 109/160
97/97 [==============================] - 29s 294ms/step - loss: 876.6734
Epoch 110/160
97/97 [==============================] - 29s 294ms/step - loss: 873.4252
Epoch 111/160
97/97 [==============================] - 29s 294ms/step - loss: 873.2639
Epoch 112/160
97/97 [==============================] - 29s 295ms/step - loss: 871.0381
Epoch 113/160
97/97 [==============================] - 29s 294ms/step - loss: 866.5417
Epoch 114/160
97/97 [==============================] - 29s 294ms/step - loss: 862.2125
Epoch 115/160
97/97 [==============================] - 29s 294ms/step - loss: 862.8839
Epoch 116/160
97/97 [==============================] - 29s 294ms/step - loss: 861.1781
Epoch 117/160
97/97 [==============================] - 29s 294ms/step - loss: 856.6186
Epoch 118/160
97/97 [==============================] - 29s 294ms/step - loss: 857.3196
Epoch 119/160
97/97 [==============================] - 29s 294ms/step - loss: 858.0576
Epoch 120/160
97/97 [==============================] - 29s 294ms/step - loss: 855.3264
Epoch 121/160
97/97 [==============================] - 29s 294ms/step - loss: 850.6841
Epoch 122/160
97/97 [==============================] - 29s 294ms/step - loss: 849.6420
Epoch 123/160
97/97 [==============================] - 29s 294ms/step - loss: 846.6933
Epoch 124/160
97/97 [==============================] - 29s 295ms/step - loss: 847.4681
Epoch 125/160
97/97 [==============================] - 29s 294ms/step - loss: 838.5893
Epoch 126/160
97/97 [==============================] - 29s 294ms/step - loss: 841.2516
Epoch 127/160
97/97 [==============================] - 29s 295ms/step - loss: 840.6940
Epoch 128/160
97/97 [==============================] - 29s 294ms/step - loss: 840.9053
Epoch 129/160
97/97 [==============================] - 29s 294ms/step - loss: 836.9998
Epoch 130/160
97/97 [==============================] - 29s 294ms/step - loss: 836.6874
Epoch 131/160
97/97 [==============================] - 29s 294ms/step - loss: 835.2166
Epoch 132/160
97/97 [==============================] - 29s 295ms/step - loss: 833.7071
Epoch 133/160
97/97 [==============================] - 29s 294ms/step - loss: 829.0735
Epoch 134/160
97/97 [==============================] - 29s 294ms/step - loss: 830.1376
Epoch 135/160
97/97 [==============================] - 29s 294ms/step - loss: 827.7781
Epoch 136/160
97/97 [==============================] - 29s 294ms/step - loss: 825.4308
Epoch 137/160
97/97 [==============================] - 29s 294ms/step - loss: 823.2223
Epoch 138/160
97/97 [==============================] - 29s 294ms/step - loss: 821.3982
Epoch 139/160
97/97 [==============================] - 29s 294ms/step - loss: 821.0161
Epoch 140/160
97/97 [==============================] - 29s 294ms/step - loss: 816.7703
Epoch 141/160
97/97 [==============================] - 29s 294ms/step - loss: 814.1747
Epoch 142/160
97/97 [==============================] - 29s 294ms/step - loss: 813.5908
Epoch 143/160
97/97 [==============================] - 29s 294ms/step - loss: 814.3353
Epoch 144/160
97/97 [==============================] - 29s 295ms/step - loss: 807.3126
Epoch 145/160
97/97 [==============================] - 29s 294ms/step - loss: 811.9185
Epoch 146/160
97/97 [==============================] - 29s 294ms/step - loss: 808.0939
Epoch 147/160
97/97 [==============================] - 29s 294ms/step - loss: 806.7361
Epoch 148/160
97/97 [==============================] - 29s 294ms/step - loss: 804.6682
Epoch 149/160
97/97 [==============================] - 29s 294ms/step - loss: 801.5149
Epoch 150/160
97/97 [==============================] - 29s 294ms/step - loss: 803.6600
Epoch 151/160
97/97 [==============================] - 29s 294ms/step - loss: 799.9028
Epoch 152/160
97/97 [==============================] - 29s 294ms/step - loss: 801.5812
Epoch 153/160
97/97 [==============================] - 29s 294ms/step - loss: 791.5322
Epoch 154/160
97/97 [==============================] - 29s 294ms/step - loss: 795.5021
Epoch 155/160
97/97 [==============================] - 29s 294ms/step - loss: 795.7894
Epoch 156/160
97/97 [==============================] - 29s 294ms/step - loss: 794.7897
Epoch 157/160
97/97 [==============================] - 29s 294ms/step - loss: 794.8560
Epoch 158/160
97/97 [==============================] - 29s 294ms/step - loss: 791.5762
Epoch 159/160
97/97 [==============================] - 29s 294ms/step - loss: 784.3605
Epoch 160/160
97/97 [==============================] - 29s 294ms/step - loss: 781.7180



Evaluation

Linear evaluation: to evaluate the model's performance, we add a linear dense layer at the end and freeze the main model's weights, so that only the dense layer is trained. If the model actually learned something, the accuracy will be significantly higher than random chance.

Accuracy on CIFAR-10: 64% for this notebook. This is much better than the 10% we would get from random guessing.

# Approx: 64% accuracy with this barlow twins model.

xy_ds = (
    tf.data.Dataset.from_tensor_slices((train_features, train_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((test_features, test_labels))
    .shuffle(1000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

model = keras.models.Sequential(
    [
        bm.model,
        keras.layers.Dense(
            10, activation="softmax", kernel_regularizer=keras.regularizers.l2(0.02)
        ),
    ]
)

model.layers[0].trainable = False

linear_optimizer = tfa.optimizers.LAMB()
model.compile(
    optimizer=linear_optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.fit(xy_ds, epochs=35, validation_data=test_ds)
Epoch 1/35
97/97 [==============================] - 12s 84ms/step - loss: 2.9447 - accuracy: 0.2090 - val_loss: 2.3056 - val_accuracy: 0.3741
Epoch 2/35
97/97 [==============================] - 6s 62ms/step - loss: 1.9912 - accuracy: 0.4867 - val_loss: 1.6910 - val_accuracy: 0.5883
Epoch 3/35
97/97 [==============================] - 6s 62ms/step - loss: 1.5476 - accuracy: 0.6278 - val_loss: 1.4605 - val_accuracy: 0.6465
Epoch 4/35
97/97 [==============================] - 6s 62ms/step - loss: 1.3775 - accuracy: 0.6647 - val_loss: 1.3689 - val_accuracy: 0.6644
Epoch 5/35
97/97 [==============================] - 6s 62ms/step - loss: 1.3027 - accuracy: 0.6769 - val_loss: 1.3232 - val_accuracy: 0.6684
Epoch 6/35
97/97 [==============================] - 6s 62ms/step - loss: 1.2574 - accuracy: 0.6820 - val_loss: 1.2905 - val_accuracy: 0.6717
Epoch 7/35
97/97 [==============================] - 6s 63ms/step - loss: 1.2244 - accuracy: 0.6852 - val_loss: 1.2654 - val_accuracy: 0.6742
Epoch 8/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1979 - accuracy: 0.6868 - val_loss: 1.2460 - val_accuracy: 0.6747
Epoch 9/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1754 - accuracy: 0.6884 - val_loss: 1.2247 - val_accuracy: 0.6773
Epoch 10/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1559 - accuracy: 0.6896 - val_loss: 1.2090 - val_accuracy: 0.6770
Epoch 11/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1380 - accuracy: 0.6907 - val_loss: 1.1904 - val_accuracy: 0.6785
Epoch 12/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1223 - accuracy: 0.6915 - val_loss: 1.1796 - val_accuracy: 0.6776
Epoch 13/35
97/97 [==============================] - 6s 62ms/step - loss: 1.1079 - accuracy: 0.6923 - val_loss: 1.1696 - val_accuracy: 0.6785
Epoch 14/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0954 - accuracy: 0.6931 - val_loss: 1.1564 - val_accuracy: 0.6795
Epoch 15/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0841 - accuracy: 0.6939 - val_loss: 1.1454 - val_accuracy: 0.6807
Epoch 16/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0733 - accuracy: 0.6945 - val_loss: 1.1356 - val_accuracy: 0.6810
Epoch 17/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0634 - accuracy: 0.6948 - val_loss: 1.1313 - val_accuracy: 0.6799
Epoch 18/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0535 - accuracy: 0.6957 - val_loss: 1.1208 - val_accuracy: 0.6808
Epoch 19/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0447 - accuracy: 0.6965 - val_loss: 1.1128 - val_accuracy: 0.6813
Epoch 20/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0366 - accuracy: 0.6968 - val_loss: 1.1082 - val_accuracy: 0.6799
Epoch 21/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0295 - accuracy: 0.6968 - val_loss: 1.0971 - val_accuracy: 0.6821
Epoch 22/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0226 - accuracy: 0.6971 - val_loss: 1.0946 - val_accuracy: 0.6799
Epoch 23/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0166 - accuracy: 0.6977 - val_loss: 1.0916 - val_accuracy: 0.6802
Epoch 24/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0103 - accuracy: 0.6980 - val_loss: 1.0823 - val_accuracy: 0.6819
Epoch 25/35
97/97 [==============================] - 6s 62ms/step - loss: 1.0052 - accuracy: 0.6981 - val_loss: 1.0795 - val_accuracy: 0.6804
Epoch 26/35
97/97 [==============================] - 6s 63ms/step - loss: 1.0001 - accuracy: 0.6984 - val_loss: 1.0759 - val_accuracy: 0.6806
Epoch 27/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9947 - accuracy: 0.6992 - val_loss: 1.0699 - val_accuracy: 0.6809
Epoch 28/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9901 - accuracy: 0.6987 - val_loss: 1.0637 - val_accuracy: 0.6821
Epoch 29/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9862 - accuracy: 0.6991 - val_loss: 1.0603 - val_accuracy: 0.6826
Epoch 30/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9817 - accuracy: 0.6994 - val_loss: 1.0582 - val_accuracy: 0.6813
Epoch 31/35
97/97 [==============================] - 6s 63ms/step - loss: 0.9784 - accuracy: 0.6994 - val_loss: 1.0531 - val_accuracy: 0.6826
Epoch 32/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9743 - accuracy: 0.6998 - val_loss: 1.0505 - val_accuracy: 0.6822
Epoch 33/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9711 - accuracy: 0.6996 - val_loss: 1.0506 - val_accuracy: 0.6800
Epoch 34/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9686 - accuracy: 0.6993 - val_loss: 1.0423 - val_accuracy: 0.6828
Epoch 35/35
97/97 [==============================] - 6s 62ms/step - loss: 0.9653 - accuracy: 0.6999 - val_loss: 1.0429 - val_accuracy: 0.6821

Conclusion

  • Barlow Twins is a simple and concise method for contrastive and self-supervised learning.
  • With this resnet-34 model architecture, we were able to reach 62-64% validation accuracy.

Use-Cases of Barlow-Twins (and contrastive learning in general)

  • Semi-supervised learning: you can see that this model reached 62-64% accuracy even though it was never trained with the labels. It can be used when you have little labeled data but a lot of unlabeled data.
  • You can train with Barlow Twins on the unlabeled data first, then do a second round of training with the labeled data, as sketched below.
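
Below is a minimal sketch of that second stage, reusing bm.model and the labeled datasets from the evaluation section above. The optimizer, learning rate, and epoch count are illustrative assumptions, not tuned values:

# Stage 2: fine-tune the pretrained encoder on the labeled data.
fine_tune_model = keras.models.Sequential(
    [bm.model, keras.layers.Dense(10, activation="softmax")]
)
# unlike linear evaluation, the encoder is unfrozen here
fine_tune_model.layers[0].trainable = True
fine_tune_model.compile(
    # small learning rate so the pretrained weights are only gently adjusted
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
fine_tune_model.fit(xy_ds, epochs=10, validation_data=test_ds)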

Helpful links