交叉编码器¶
SentenceTransformers 还支持训练用于句子对评分和句子对分类任务的交叉编码器。有关交叉编码器的详细信息以及交叉编码器与双编码器之间的区别,请参阅交叉编码器。
示例¶
请参阅以下示例,了解如何训练交叉编码器:
training_stsbenchmark.py - 如何在 STS 基准数据集上训练用于语义文本相似性 (STS) 的示例。
training_quora_duplicate_questions.py - 训练交叉编码器以预测两个问题是否重复的示例。使用 Quora 重复问题作为训练数据集。
training_nli.py - 自然语言推理 (NLI) 任务的多标签分类任务示例。
训练交叉编码器¶
CrossEncoder
类是 Hugging Face AutoModelForSequenceClassification
的包装器,但包含一些方法使得训练和预测评分更加简便。保存的模型与 Hugging Face 完全兼容,也可以使用它们的类加载。
首先,您需要一些句子对数据。您可以有一个连续的分数,例如:
from sentence_transformers import InputExample
train_samples = [
InputExample(texts=["sentence1", "sentence2"], label=0.3),
InputExample(texts=["Another", "pair"], label=0.8),
]
或者,您有如 training_nli.py 示例中的不同类别:
from sentence_transformers import InputExample
label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = [
InputExample(texts=["sentence1", "sentence2"], label=label2int["neutral"]),
InputExample(texts=["Another", "pair"], label=label2int["entailment"]),
]
然后,您定义基础模型和标签数量。您可以使用任何与 AutoModel 兼容的 Hugging Face 预训练模型:
model = CrossEncoder('distilroberta-base', num_labels=1)
对于二元任务和具有连续分数的任务(如 STS),我们将 num_labels 设置为 1。对于分类任务,我们将其设置为我们拥有的标签数量。
我们通过调用 model.fit()
来开始训练:
model.fit(
train_dataloader=train_dataloader,
evaluator=evaluator,
epochs=num_epochs,
warmup_steps=warmup_steps,
output_path=model_save_path,
)