from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from torch import Tensor, nn

from sentence_transformers.models import Transformer
from sentence_transformers.SentenceTransformer import SentenceTransformer

class GISTEmbedLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
guide: SentenceTransformer,
temperature: float = 0.01,
) -> None:
"""
This loss is used to train a SentenceTransformer model using the GISTEmbed algorithm.
It takes a model and a guide model as input, and uses the guide model to guide the
        in-batch negative sample selection. The cosine similarity is used to compute the loss,
        and the temperature parameter is used to scale the cosine similarities.
Args:
model: SentenceTransformer model based on a `transformers`
model.
guide: SentenceTransformer model to guide the in-batch
negative sample selection.
temperature: Temperature parameter to scale the cosine
similarities.
References:
- For further details, see: https://arxiv.org/abs/2402.16829
Requirements:
1. (anchor, positive, negative) triplets
2. (anchor, positive) pairs
Inputs:
+---------------------------------------+--------+
| Texts | Labels |
+=======================================+========+
| (anchor, positive, negative) triplets | none |
+---------------------------------------+--------+
| (anchor, positive) pairs | none |
+---------------------------------------+--------+
Recommendations:
- Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
ensure that no in-batch negatives are duplicates of the anchor or positive samples.
Relations:
- :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use
a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields
a stronger training signal at the cost of some training overhead.
Example:
::
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
from datasets import Dataset
model = SentenceTransformer("microsoft/mpnet-base")
guide = SentenceTransformer("all-MiniLM-L6-v2")
train_dataset = Dataset.from_dict({
"anchor": ["It's nice weather outside today.", "He drove to work."],
"positive": ["It's so sunny.", "He took the car to the office."],
})
loss = losses.GISTEmbedLoss(model, guide)
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
loss=loss,
)
trainer.train()
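
            This loss also accepts (anchor, positive, negative) triplets. A minimal sketch,
            assuming the dataset simply gains a ``negative`` column::

                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today."],
                    "positive": ["It's so sunny."],
                    "negative": ["It's quite rainy, sadly."],
                })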
"""
super().__init__()
self.model = model
self.guide = guide
self.temperature = temperature
self.similarity_fct = nn.CosineSimilarity(dim=-1)
if not isinstance(model[0], Transformer) or not isinstance(guide[0], Transformer):
raise ValueError(
"Both the training model and the guiding model must be based on the `transformers` architecture."
)
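        # The guide must re-tokenize the inputs if the two models do not share a tokenizer
        # vocabulary, or if the guide truncates at a shorter maximum sequence length.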
self.must_retokenize = (
model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length
)
if self.must_retokenize:
self.tokenizer = self.model.tokenizer

    def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
        # Broadcast to the full pairwise cosine-similarity matrix:
        # (batch_1, 1, dim) vs. (1, batch_2, dim) -> (batch_1, batch_2).
        return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
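        # The guide model only steers the negative selection; its embeddings are computed
        # without gradients and never receive updates.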
with torch.no_grad():
if self.must_retokenize:
decoded = [
self.tokenizer.batch_decode(sentence_feature["input_ids"], skip_special_tokens=True)
for sentence_feature in sentence_features
]
sentence_features = [self.guide.tokenize(sentences) for sentences in decoded]
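                # Move the re-tokenized batches onto the guide model's device.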
sentence_features = [
{key: value.to(self.guide.device) for key, value in sentence_feature.items()}
for sentence_feature in sentence_features
]
guide_embeddings = [
self.guide(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features
]
negative = None
negative_guide = None
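        # Two input columns correspond to (anchor, positive) pairs; three columns
        # correspond to (anchor, positive, negative) triplets.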
if len(embeddings) == 2:
anchor, positive = embeddings
anchor_guide, positive_guide = guide_embeddings
elif len(embeddings) == 3:
anchor, positive, negative = embeddings
anchor_guide, positive_guide, negative_guide = guide_embeddings
else:
raise ValueError(f"Expected 2 or 3 embeddings, got {len(embeddings)}")
# Compute the model's similarities
ap_sim = self.sim_matrix(anchor, positive)
aa_sim = self.sim_matrix(anchor, anchor)
pp_sim = self.sim_matrix(positive, positive)
        # Compute the guide model's similarity matrices for the same anchor/positive combinations.
guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide)
guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide)
guided_pp_sim = self.sim_matrix(positive_guide, positive_guide)
        # The guided anchor-positive similarity on the diagonal serves as a per-anchor threshold.
        guided_sim = guided_ap_sim.diagonal().view(-1, 1)
        # Find which samples cannot be used as negatives because the guide model deems them
        # more similar to the query than the assigned positive. Mask them with -inf so they
        # are effectively ignored by the cross-entropy loss.
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf
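        # For example, if the guide rates anchor i against positive j (j != i) higher than
        # against its own positive i, candidate j is treated as a likely false negative
        # for anchor i and removed from the softmax denominator.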
scores = [ap_sim, aa_sim, pp_sim]
# Handle the case where we have a negative sample
if negative is not None:
an_sim = self.sim_matrix(anchor, negative)
guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
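            # Explicit negatives that the guide rates above the anchor's positive are masked too.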
an_sim[guided_an_sim > guided_sim] = -torch.inf
scores.append(an_sim)
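        # Concatenate all candidate similarities and sharpen them with the (small) temperature
        # before the softmax-based cross-entropy.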
scores = torch.cat(scores, dim=1) / self.temperature
# NOTE: We use arange here since the ap_sim matrix contains the anchor-positive
# similarities along the diagonal.
        labels = torch.arange(scores.size(0), device=scores.device)
return nn.CrossEntropyLoss()(scores, labels)

    def get_config_dict(self) -> dict[str, Any]:
return {
"guide": self.guide,
"temperature": self.temperature,
}

    @property
def citation(self) -> str:
return """
@misc{solatorio2024gistembed,
title={GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning},
author={Aivin V. Solatorio},
year={2024},
eprint={2402.16829},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""