培训概述

为什么要微调?

微调 Sentence Transformer 模型通常会显著提升模型在你使用场景中的性能,因为每个任务需要的相似性概念不同。例如,给定新闻文章:

  • 苹果发布新款iPad

  • NVIDIA 正在为下一代 GPU 做准备

然后,以下用例中,我们可能对相似性有不同的概念:

  • 一个用于将新闻文章分类为经济、体育、科技、政治等的模型,应为这些文本生成相似的嵌入。

  • 一个用于语义文本相似性的模型应该为这些文本生成不同的嵌入,因为它们具有不同的含义。

  • 语义搜索的模型不需要两个文档之间相似性的概念,因为它应该只比较查询和文档。

另请参阅训练示例,其中包含许多适用于常见实际应用的训练脚本,您可以采用这些脚本。

训练组件

训练 Sentence Transformer 模型涉及 3 到 5 个组件:

数据集

使用 :class:SentenceTransformerTrainer 进行训练和评估,可以使用 :class:datasets.Dataset(一个数据集)或 :class:datasets.DatasetDict 实例(多个数据集,参见 多数据集训练 <#multi-dataset-training>_)。

如果你想从 Hugging Face 数据集 <https://huggingface.co/datasets>_ 加载数据,那么你应该使用 :func:datasets.load_dataset:

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")

print(train_dataset)
"""
Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})
"""

一些数据集(包括 sentence-transformers/all-nli <https://huggingface.co/datasets/sentence-transformers/all-nli>)要求你在数据集名称旁边提供一个“子集”。sentence-transformers/all-nli 有 4 个子集,每个子集具有不同的数据格式:pair <https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/pair>pair-class <https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/pair-class>pair-score <https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/pair-score>triplet <https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/triplet>_。

备注

许多与 Sentence Transformers 开箱即用的 Hugging Face 数据集已被标记为 sentence-transformers,使您可以通过浏览 https://huggingface.co/datasets?other=sentence-transformers <https://huggingface.co/datasets?other=sentence-transformers>_ 轻松找到它们。我们强烈建议您浏览这些数据集,以找到可能对您的任务有用的训练数据集。

如果你有本地数据在常见的文件格式中,那么你可以使用 :func:datasets.load_dataset 轻松加载这些数据:

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")

或者::

from datasets import load_dataset

dataset = load_dataset("json", data_files="my_file.json")

如果你有需要额外预处理的本地的数据,我的建议是使用 :meth:datasets.Dataset.from_dict 和列表的字典来初始化你的数据集,如下所示:

from datasets import Dataset

anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "anchor": anchors,
    "positive": positives,
})

字典中的每个键将成为结果数据集中的一个列。

数据集格式

确保你的数据集格式与损失函数匹配(或者选择一个与数据集格式匹配的损失函数)非常重要。验证数据集格式是否适用于损失函数涉及两个步骤:

  1. 如果你的损失函数需要一个 标签 ,根据 损失概览 <loss_overview.html>_ 表,那么你的数据集必须有一个 名为“label”或“score”的列。该列将自动作为标签。

  2. 所有未命名为“label”或“score”的列根据 Loss Overview <loss_overview.html>_ 表被视为 Inputs。剩余列的数量必须与所选损失的有效输入数量匹配。这些列的名称是 无关紧要的,只有 顺序重要

例如,给定一个包含列 ["text1", "text2", "label"] 的数据集,其中 "label" 列包含浮点相似度分数,我们可以将其与 :class:~sentence_transformers.losses.CoSENTLoss、:class:~sentence_transformers.losses.AnglELoss 和 :class:~sentence_transformers.losses.CosineSimilarityLoss 一起使用,因为它:

  1. 有一个“标签”列,这是这些损失函数所要求的。

  2. 有2个非标签列,正好是这些损失函数所需的列数。

如果你的数据列顺序不正确,请务必使用 :meth:Dataset.select_columns <datasets.Dataset.select_columns> 重新排列你的数据集列。例如,如果你的数据集列是 ["good_answer", "bad_answer", "question"],那么这个数据集在技术上可以用于需要 (锚点, 正样本, 负样本) 三元组的损失函数,但 good_answer 列将被视为锚点,bad_answer 视为正样本,而 question 视为负样本。

此外,如果你的数据集包含无关的列(例如 sample_id、metadata、source、type),你应该使用 :meth:Dataset.remove_columns <datasets.Dataset.remove_columns> 删除这些列,否则它们将被用作输入。你也可以使用 :meth:Dataset.select_columns <datasets.Dataset.select_columns> 只保留所需的列。

损失函数

损失函数量化了模型对给定批次数据的执行情况,允许优化器更新模型权重以产生更有利的(即更低的)损失值。这是训练过程的核心。

遗憾的是,没有一种损失函数适用于所有用例。相反,使用哪种损失函数在很大程度上取决于你可用的数据和你的目标任务。请参阅数据集格式以了解哪些数据集对哪些损失函数有效。此外,损失概述将成为你了解选项的最佳伙伴。

大多数损失函数只需使用你正在训练的 :class:SentenceTransformer 进行初始化,同时可以伴随一些可选参数,例如:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss

# Load a model to train/finetune
model = SentenceTransformer("xlm-roberta-base")

# Initialize the CoSENTLoss
# This loss requires pairs of text and a float similarity score as a label
loss = CoSENTLoss(model)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
"""
Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 942069
})
"""

训练参数

可以使用 :class:~sentence_transformers.training_args.SentenceTransformerTrainingArguments 类来指定影响训练性能的参数以及定义跟踪/调试参数。虽然它是可选的,但强烈建议尝试各种有用的参数。

以下是一些最有用的训练参数的表格。



以下是如何初始化 :class:~sentence_transformers.training_args.SentenceTransformerTrainingArguments 的示例:

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

评估器

你可以为 SentenceTransformerTrainer 提供一个 eval_dataset 以在训练期间获取评估损失,但在训练期间获取更多具体的指标也可能是有用的。为此,你可以使用评估器在训练前、训练中或训练后评估模型的性能,并使用有用的指标。你可以同时使用 eval_dataset 和评估器,或者只使用其中一个,或者两者都不使用。它们根据 eval_strategy 和 eval_steps 训练参数进行评估。

以下是Sentence Transformers附带的已实现评估器:

评估器

所需数据

BinaryClassificationEvaluator

与类标签配对

EmbeddingSimilarityEvaluator

配对及其相似度得分

InformationRetrievalEvaluator

查询(qid => 问题),语料库(cid => 文档),以及相关文档(qid => 集合[cid])

MSEEvaluator

源句子嵌入教师模型,目标句子嵌入学生模型。可以是相同的文本。

ParaphraseMiningEvaluator

ID 到句子及其重复句子 ID 对的映射。

RerankingEvaluator

{'query': '...', 'positive': [...], 'negative': [...]} 字典列表。

TranslationEvaluator

两种不同语言中的句子对。

TripletEvaluator

(锚点,正样本,负样本)三元组。

此外,应使用 :class:~sentence_transformers.evaluation.SequentialEvaluator 将多个评估器组合成一个可以传递给 :class:~sentence_transformers.trainer.SentenceTransformerTrainer 的评估器。

有时你没有所需的评估数据来自己准备这些评估器,但你仍然希望跟踪模型在某些常见基准上的表现。在这种情况下,你可以使用这些评估器并结合Hugging Face的数据。

from datasets import load_dataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# You can run evaluation like so:
# dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction

# Load triplets from the AllNLI dataset (https://huggingface.co/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")

# Initialize the evaluator
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    main_distance_function=SimilarityFunction.COSINE,
    name="all-nli-dev",
)
# You can run evaluation like so:
# dev_evaluator(model)

警告

在使用 分布式训练 <training/distributed.html>_ 时,评估器仅在第一个设备上运行,与训练和评估数据集不同,后者在所有设备之间共享。

训练师

The :class:~sentence_transformers.SentenceTransformerTrainer 是所有先前组件的集合地。我们只需要用模型、训练参数(可选)、训练数据集、评估数据集(可选)、损失函数、评估器(可选)来指定训练器,就可以开始训练了。让我们来看一个所有这些组件结合在一起的脚本:

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
    "microsoft/mpnet-base",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="MPNet base trained on AllNLI triplets",
    )
)

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]

# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)

# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/mpnet-base-all-nli-triplet",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
    anchors=eval_dataset["anchor"],
    positives=eval_dataset["positive"],
    negatives=eval_dataset["negative"],
    name="all-nli-dev",
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
    name="all-nli-test",
)
test_evaluator(model)

# 8. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("mpnet-base-all-nli-triplet")

回调

这个 Sentence Transformers 训练器集成了对各种 :class:transformers.TrainerCallback 子类的支持,例如:

  • :class:~transformers.integrations.WandbCallback 用于在安装了 wandb 的情况下自动将训练指标记录到 W&B。

  • :class:~transformers.integrations.TensorBoardCallback 用于在 tensorboard 可访问时将训练指标记录到 TensorBoard。

  • :class:~transformers.integrations.CodeCarbonCallback 用于在安装了 codecarbon 的情况下,跟踪模型训练期间的碳排放。

    • 注意:这些碳排放将被包含在你自动生成的模型卡片中。

有关集成回调以及如何编写自己的回调的更多信息,请参阅 Transformers 的 Callbacks <https://huggingface.co/docs/transformers/main/en/main_classes/callback>_ 文档。

多数据集训练

表现最好的模型是通过同时使用多个数据集进行训练的。通常,这相当棘手,因为每个数据集都有不同的格式。然而,:class:SentenceTransformerTrainer 可以在不将每个数据集转换为相同格式的情况下训练多个数据集。它甚至可以对每个数据集应用不同的损失函数。使用多个数据集进行训练的步骤是:

  • 使用 :class:~datasets.Dataset 实例(或 :class:~datasets.DatasetDict)的字典作为 train_dataseteval_dataset

  • (可选)使用一个损失函数字典,将数据集名称映射到损失函数。仅在你希望对不同数据集使用不同损失函数时才需要。

每个训练/评估批次将仅包含来自一个数据集的样本。从多个数据集中抽取批次的顺序由 :class:~sentence_transformers.training_args.MultiDatasetBatchSamplers 枚举定义,可以通过 multi_dataset_batch_sampler 传递给 :class:~sentence_transformers.training_args.SentenceTransformerTrainingArguments。有效选项包括:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: 从每个数据集中按顺序轮询采样,直到其中一个数据集耗尽。使用这种策略,可能不会使用每个数据集中的所有样本,但每个数据集都会被均匀采样。

  • MultiDatasetBatchSamplers.PROPORTIONAL (默认): 按数据集大小比例从每个数据集中采样。使用此策略,每个数据集的所有样本都会被使用,并且较大的数据集会被更频繁地采样。

这种多任务训练已被证明非常有效,例如 Huang et al. <https://arxiv.org/pdf/2405.06932>_ 使用了 :class:~sentence_transformers.losses.MultipleNegativesRankingLoss、:class:~sentence_transformers.losses.CoSENTLoss,以及一种没有批次内负样本和只有硬负样本的 :class:~sentence_transformers.losses.MultipleNegativesRankingLoss 变体,在中国达到了最先进的性能。他们甚至应用了 :class:~sentence_transformers.losses.MatryoshkaLoss,使模型能够生成 Matryoshka Embeddings <../../examples/training/matryoshka/README.html>_。

在多个数据集上进行训练看起来是这样的:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss

# 1. Load a model to finetune
model = SentenceTransformer("bert-base-uncased")

# 2. Load several Datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train[:10000]")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:10000]")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train[:10000]")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:10000]")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:10000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:10000]")

# We can combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
    "all-nli-pair": all_nli_pair_train,
    "all-nli-pair-class": all_nli_pair_class_train,
    "all-nli-pair-score": all_nli_pair_score_train,
    "all-nli-triplet": all_nli_triplet_train,
    "stsb": stsb_pair_score_train,
    "quora": quora_pair_train,
    "natural-questions": natural_questions_train,
}

# 3. Load several Datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive)
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[10000:11000]")
# (query, answer)
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[10000:11000]")

# We can use a dictionary for the evaluation dataset too, but we don't have to. We could also just use
# no evaluation dataset, or one dataset.
eval_dataset = {
    "all-nli-triplet": all_nli_triplet_dev,
    "stsb": stsb_pair_score_dev,
    "quora": quora_pair_dev,
    "natural-questions": natural_questions_dev,
}

# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class
softmax_loss = SoftmaxLoss(model, model.get_sentence_embedding_dimension(), 3)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)

# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where.
# Note that you can also just use one loss if all of your training/evaluation datasets use the same loss
losses = {
    "all-nli-pair": mnrl_loss,
    "all-nli-pair-class": softmax_loss,
    "all-nli-pair-score": cosent_loss,
    "all-nli-triplet": mnrl_loss,
    "stsb": cosent_loss,
    "quora": mnrl_loss,
    "natural-questions": mnrl_loss,
}

# 5. Define a simple trainer, although it's recommended to use one with args & evaluators
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=losses,
)
trainer.train()

# 6. save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("bert-base-all-nli-stsb-quora-nq")
model.push_to_hub("bert-base-all-nli-stsb-quora-nq")

已弃用的训练

在 Sentence Transformers v3.0 发布之前,模型会使用 :meth:SentenceTransformer.fit <sentence_transformers.SentenceTransformer.fit> 方法和 :class:~torch.utils.data.DataLoader 的 :class:~sentence_transformers.readers.InputExample 进行训练,大致如下::

from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader

# Define the model. Either from scratch of by loading a pre-trained model
model = SentenceTransformer("distilbert/distilbert-base-uncased")

# Define your train examples. You need more than just two examples...
train_examples = [
    InputExample(texts=["My first sentence", "My second sentence"], label=0.8),
    InputExample(texts=["Another pair", "Unrelated sentence"], label=0.3),
]

# Define your train dataset, the dataloader and the train loss
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.CosineSimilarityLoss(model)

# Tune the model
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)

自v3.0版本发布以来,使用 :meth:SentenceTransformer.fit <sentence_transformers.SentenceTransformer.fit> 仍然是可能的,但它将在后台初始化一个 :class:~sentence_transformers.trainer.SentenceTransformerTrainer。建议直接使用Trainer,因为通过 :class:~sentence_transformers.training_args.SentenceTransformerTrainingArguments 你可以有更多的控制,但依赖于 :meth:SentenceTransformer.fit <sentence_transformers.SentenceTransformer.fit> 的现有训练脚本应该仍然可以工作。

如果在更新后的 :meth:SentenceTransformer.fit <sentence_transformers.SentenceTransformer.fit> 中遇到问题,你也可以通过调用 :meth:SentenceTransformer.old_fit <sentence_transformers.SentenceTransformer.old_fit> 来获得完全旧的行为,但这种方法在未来将被完全弃用。

最佳基础嵌入模型

你的文本嵌入模型的质量取决于你选择的transformer模型。遗憾的是,我们不能仅从例如GLUE或SuperGLUE基准测试中的更好表现推断出该模型也会产生更好的表示。

为了测试变压器模型的适用性,我使用 training_nli_v2.py 脚本,并在 560k 个(锚点,正样本,负样本)三元组上训练 1 个周期,批量大小为 64。然后,我在来自不同领域的 14 个多样化的文本相似性任务(聚类、语义搜索、重复检测等)上进行评估。

在下表中,您可以找到不同模型的性能及其在此基准测试中的表现:

Model Performance (14 sentence similarity tasks)
microsoft/mpnet-base 60.99
nghuyong/ernie-2.0-en 60.73
microsoft/deberta-base 60.21
roberta-base 59.63
t5-base 59.21
bert-base-uncased 59.17
distilbert-base-uncased 59.03
nreimers/TinyBERT_L-6_H-768_v2 58.27
google/t5-v1_1-base 57.63
nreimers/MiniLMv2-L6-H768-distilled-from-BERT-Large 57.31
albert-base-v2 57.14
microsoft/MiniLM-L12-H384-uncased 56.79
microsoft/deberta-v3-base 54.46