from __future__ import annotations
from typing import Any, Iterable
import torch
from torch import Tensor, nn
from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer
class MultipleNegativesRankingLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim) -> None:
        """
        This loss expects as input a batch consisting of sentence pairs ``(a_1, p_1), (a_2, p_2)..., (a_n, p_n)``
        where we assume that ``(a_i, p_i)`` are a positive pair and ``(a_i, p_j)`` for ``i != j`` a negative pair.
        For each ``a_i``, it uses all other ``p_j`` as negative samples, i.e., for ``a_i``, we have 1 positive example
        (``p_i``) and ``n-1`` negative examples (``p_j``). It then minimizes the negative log-likelihood for softmax
        normalized scores.

        This loss function works great to train embeddings for retrieval setups where you have positive pairs
        (e.g. (query, relevant_doc)), as it will use the ``n-1`` other in-batch documents as negatives for each query.
        The performance usually increases with increasing batch sizes.

        You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
        ``(a_1, p_1, n_1), (a_2, p_2, n_2)``. Then, ``n_1`` is a hard negative for ``(a_1, p_1)``. The loss will use for
        the pair ``(a_i, p_i)`` all ``p_j`` for ``j != i`` and all ``n_j`` as negatives.
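
        As a rough sketch of what happens internally (illustrative, not part of the public API): with a batch of
        ``n`` such triplets, the loss builds an ``n x 2n`` similarity matrix and applies cross-entropy with the
        matching column as the target::

            # anchors:    [a_1, ..., a_n]
            # candidates: [p_1, ..., p_n, n_1, ..., n_n]
            # scores[i][j] = similarity_fct(a_i, candidates[j]) * scale
            # The target for row i is column i (i.e. p_i); every other column is treated as a negative.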

        Args:
            model: SentenceTransformer model
            scale: Output of similarity function is multiplied by scale value
            similarity_fct: similarity function between sentence embeddings. By default, cos_sim.
                Can also be set to dot product (and then set scale to 1); see the sketch below.
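
        For example, a dot-product variant might look like the following sketch (``model`` is assumed to be an
        already loaded :class:`SentenceTransformer`; ``util.dot_score`` is the dot-product similarity shipped with
        the library)::

            from sentence_transformers import losses, util

            loss = losses.MultipleNegativesRankingLoss(model, scale=1.0, similarity_fct=util.dot_score)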

        References:
            - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
            - `Training Examples > Natural Language Inference <../../examples/training/nli/README.html>`_
            - `Training Examples > Paraphrase Data <../../examples/training/paraphrases/README.html>`_
            - `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_
            - `Training Examples > MS MARCO <../../examples/training/ms_marco/README.html>`_
            - `Unsupervised Learning > SimCSE <../../examples/unsupervised_learning/SimCSE/README.html>`_
            - `Unsupervised Learning > GenQ <../../examples/unsupervised_learning/query_generation/README.html>`_

        Requirements:
            1. (anchor, positive) pairs or (anchor, positive, negative) triplets

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive) pairs              | none   |
            +---------------------------------------+--------+
            | (anchor, positive, negative) triplets | none   |
            +---------------------------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that no in-batch negatives are duplicates of the anchor or positive samples; a sketch of this
              setup follows below.
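
        A minimal sketch of that recommendation (assuming the ``SentenceTransformerTrainingArguments`` API; the
        output directory is illustrative)::

            from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

            args = SentenceTransformerTrainingArguments(
                output_dir="output",  # illustrative output path
                batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts within a batch
            )
            # Pass these arguments to the trainer via SentenceTransformerTrainer(..., args=args)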

        Relations:
            - :class:`CachedMultipleNegativesRankingLoss` is equivalent to this loss, but it uses caching that allows for
              much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly
              slower.
            - :class:`MultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but with an additional loss term.
            - :class:`GISTEmbedLoss` is equivalent to this loss, but uses a guide model to guide the in-batch negative
              sample selection. `GISTEmbedLoss` yields a stronger training signal at the cost of some training overhead.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
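
        The loss also accepts hard-negative triplets. The variant below is a sketch (the ``negative`` sentences are
        made up for illustration) showing the same setup with an extra ``negative`` column::

            from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
            from datasets import Dataset

            model = SentenceTransformer("microsoft/mpnet-base")
            train_dataset = Dataset.from_dict({
                "anchor": ["It's nice weather outside today.", "He drove to work."],
                "positive": ["It's so sunny.", "He took the car to the office."],
                "negative": ["It's quite rainy, sadly.", "She walked to the store."],
            })
            loss = losses.MultipleNegativesRankingLoss(model)

            trainer = SentenceTransformerTrainer(
                model=model,
                train_dataset=train_dataset,
                loss=loss,
            )
            trainer.train()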
"""
super().__init__()
self.model = model
self.scale = scale
self.similarity_fct = similarity_fct
self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # Compute embeddings for every input column: [anchors, positives, (optional) hard negatives, ...]
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        # All positives and hard negatives together form the candidate pool shared by every anchor
        embeddings_b = torch.cat(reps[1:])

        # (batch_size, num_candidates) similarity matrix, scaled before the softmax
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        # Example a[i] should match with b[i]: the i-th candidate is the positive for the i-th anchor
        range_labels = torch.arange(0, scores.size(0), device=scores.device)
        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        return """
@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""