Source code for sentence_transformers.losses.ContrastiveLoss

from __future__ import annotations

from enum import Enum
from typing import Any, Iterable

import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer


class SiameseDistanceMetric(Enum):
    """The metric for the contrastive loss"""

    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
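

As a quick illustration (not part of the module source), the three metrics can be called directly as plain functions on batches of embeddings; the tensors below are made-up toy data:

import torch

from sentence_transformers.losses import SiameseDistanceMetric

a = torch.randn(4, 8)  # toy "anchor" embeddings, shape (batch, dim)
b = torch.randn(4, 8)  # toy "other" embeddings, shape (batch, dim)

print(SiameseDistanceMetric.EUCLIDEAN(a, b))        # L2 distances, shape (4,)
print(SiameseDistanceMetric.MANHATTAN(a, b))        # L1 distances, shape (4,)
print(SiameseDistanceMetric.COSINE_DISTANCE(a, b))  # 1 - cosine similarity, values in [0, 2]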


class ContrastiveLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
        margin: float = 0.5,
        size_average: bool = True,
    ) -> None:
        """
        Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the
        distance between the two embeddings is reduced. If the label == 0, then the distance between the
        embeddings is increased.

        Args:
            model: SentenceTransformer model
            distance_metric: Function that returns a distance between two embeddings. The class
                SiameseDistanceMetric contains pre-defined metrics that can be used
            margin: Negative samples (label == 0) should have a distance of at least the margin value.
            size_average: Average by the size of the mini-batch.

        References:
            * Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            * `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_

        Requirements:
            1. (anchor, positive/negative) pairs

        Inputs:
            +-----------------------------------------------+------------------------------+
            | Texts                                         | Labels                       |
            +===============================================+==============================+
            | (anchor, positive/negative) pairs             | 1 if positive, 0 if negative |
            +-----------------------------------------------+------------------------------+

        Relations:
            - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs.
              It often yields better results.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "label": [1, 0],
                })
                loss = losses.ContrastiveLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self) -> dict[str, Any]:
        distance_metric_name = self.distance_metric.__name__
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = f"SiameseDistanceMetric.{name}"
                break

        return {"distance_metric": distance_metric_name, "margin": self.margin, "size_average": self.size_average}

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        assert len(reps) == 2
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        losses = 0.5 * (
            labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2)
        )
        return losses.mean() if self.size_average else losses.sum()

    @property
    def citation(self) -> str:
        return """
@inproceedings{hadsell2006dimensionality,
    author={Hadsell, R. and Chopra, S. and LeCun, Y.},
    booktitle={2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06)},
    title={Dimensionality Reduction by Learning an Invariant Mapping},
    year={2006},
    volume={2},
    number={},
    pages={1735-1742},
    doi={10.1109/CVPR.2006.100}
}
"""
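

For reference, the forward pass above implements the Hadsell et al. (2006) formulation loss = 0.5 * (y * d^2 + (1 - y) * max(0, margin - d)^2), so positive pairs are penalized by their squared distance and negative pairs only when they fall inside the margin. The following standalone sketch (not part of the library source) mirrors that arithmetic on made-up distances:

import torch
import torch.nn.functional as F

margin = 0.5
labels = torch.tensor([1.0, 0.0, 0.0])     # 1 = positive pair, 0 = negative pair
distances = torch.tensor([0.2, 0.1, 0.9])  # toy pairwise distances

# The last pair contributes 0 because its distance (0.9) already exceeds the margin (0.5).
losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * F.relu(margin - distances).pow(2))
print(losses)         # tensor([0.0200, 0.0800, 0.0000])
print(losses.mean())  # what ContrastiveLoss returns with size_average=True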