交叉编码器

SentenceTransformers 还支持训练用于句子对评分和句子对分类任务的交叉编码器。有关交叉编码器的详细信息以及交叉编码器与双编码器之间的区别,请参阅交叉编码器

示例

请参阅以下示例,了解如何训练交叉编码器:

训练交叉编码器

CrossEncoder 类是 Hugging Face AutoModelForSequenceClassification 的包装器,但包含一些方法使得训练和预测评分更加简便。保存的模型与 Hugging Face 完全兼容,也可以使用它们的类加载。

首先,您需要一些句子对数据。您可以有一个连续的分数,例如:

from sentence_transformers import InputExample

train_samples = [
    InputExample(texts=["sentence1", "sentence2"], label=0.3),
    InputExample(texts=["Another", "pair"], label=0.8),
]

或者,您有如 training_nli.py 示例中的不同类别:

from sentence_transformers import InputExample

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = [
    InputExample(texts=["sentence1", "sentence2"], label=label2int["neutral"]),
    InputExample(texts=["Another", "pair"], label=label2int["entailment"]),
]

然后,您定义基础模型和标签数量。您可以使用任何与 AutoModel 兼容的 Hugging Face 预训练模型

model = CrossEncoder('distilroberta-base', num_labels=1)

对于二元任务和具有连续分数的任务(如 STS),我们将 num_labels 设置为 1。对于分类任务,我们将其设置为我们拥有的标签数量。

我们通过调用 model.fit() 来开始训练:

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)