Source code for sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss

from __future__ import annotations

from typing import Any, Iterable

import torch
from torch import Tensor, nn

from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class MultipleNegativesSymmetricRankingLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim) -> None:
        """
        Given a list of (anchor, positive) pairs, this loss sums the following two losses:

        1. Forward loss: Given an anchor, find the sample with the highest similarity out of all positives in the batch.
           This is equivalent to :class:`MultipleNegativesRankingLoss`.
        2. Backward loss: Given a positive, find the sample with the highest similarity out of all anchors in the batch.

        For example with question-answer pairs, :class:`MultipleNegativesRankingLoss` just computes the loss to find
        the answer given a question, but :class:`MultipleNegativesSymmetricRankingLoss` additionally computes the loss
        to find the question given an answer.

        Note: If you pass triplets, the negative entry will be ignored. An anchor is just searched for the positive.

        Args:
            model: SentenceTransformer model
            scale: Output of similarity function is multiplied by scale value
            similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to
                dot product (and then set scale to 1)

        Requirements:
            1. (anchor, positive) pairs

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive) pairs              | none   |
            +---------------------------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that no in-batch negatives are duplicates of the anchor or positive samples.

        Relations:
            - Like :class:`MultipleNegativesRankingLoss`, but with an additional loss term.
            - :class:`CachedMultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but it uses caching that
              allows for much higher batch sizes (and thus better performance) without extra memory usage. However,
              it is slightly slower.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesSymmetricRankingLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # Embed all columns: reps[0] holds the anchors, reps[1:] the positives (and any extra negatives)
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        anchor = reps[0]
        candidates = torch.cat(reps[1:])

        scores = self.similarity_fct(anchor, candidates) * self.scale
        labels = torch.tensor(
            range(len(scores)), dtype=torch.long, device=scores.device
        )  # Example a[i] should match with b[i]

        # Restrict to the (anchor, positive) block; extra negatives are excluded from the backward direction
        anchor_positive_scores = scores[:, 0 : len(reps[1])]
        forward_loss = self.cross_entropy_loss(scores, labels)  # anchor -> positive (and negatives)
        backward_loss = self.cross_entropy_loss(anchor_positive_scores.transpose(0, 1), labels)  # positive -> anchor
        return (forward_loss + backward_loss) / 2

    def get_config_dict(self) -> dict[str, Any]:
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}
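
A minimal standalone sketch (not part of the library) of the symmetric in-batch cross-entropy that ``forward`` computes, using random tensors in place of real sentence embeddings. The batch size, embedding dimension, and ``scale`` value here are illustrative assumptions; the pair-only setup corresponds to the case without extra in-batch negatives::

    import torch
    from torch import nn
    from sentence_transformers import util

    batch_size, dim, scale = 4, 8, 20.0
    anchor = torch.randn(batch_size, dim)    # stands in for reps[0]
    positive = torch.randn(batch_size, dim)  # stands in for reps[1]; no extra negatives

    scores = util.cos_sim(anchor, positive) * scale  # (batch_size, batch_size) similarity matrix
    labels = torch.arange(batch_size)                # anchor[i] should match positive[i]

    cross_entropy = nn.CrossEntropyLoss()
    forward_loss = cross_entropy(scores, labels)     # anchor -> positive direction
    backward_loss = cross_entropy(scores.T, labels)  # positive -> anchor direction
    loss = (forward_loss + backward_loss) / 2

With only (anchor, positive) pairs, the score matrix is square, so the backward direction is simply the transposed matrix scored against the same diagonal labels; with extra negative columns, only the anchor-positive block is transposed, as in ``forward`` above.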