Source code for sentence_transformers.losses.TripletLoss

from __future__ import annotations

from enum import Enum
from typing import Any, Iterable

import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer


class TripletDistanceMetric(Enum):
    """The metric for the triplet loss"""

    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
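    # Illustrative note (not part of the original source): each metric maps a pair of
    # embedding batches to a per-row distance. The toy vectors below are assumptions
    # used purely to show the expected magnitudes:
    #
    #   import torch
    #   a = torch.tensor([[1.0, 0.0]])
    #   b = torch.tensor([[0.0, 1.0]])
    #   TripletDistanceMetric.COSINE(a, b)     # ~1.0    (1 - cosine similarity of orthogonal vectors)
    #   TripletDistanceMetric.EUCLIDEAN(a, b)  # ~1.4142 (sqrt(2))
    #   TripletDistanceMetric.MANHATTAN(a, b)  # ~2.0    (|1 - 0| + |0 - 1|)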


[docs]
class TripletLoss(nn.Module):
    def __init__(
        self, model: SentenceTransformer, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 5
    ) -> None:
        """
        This class implements triplet loss. Given a triplet of (anchor, positive, negative),
        the loss minimizes the distance between anchor and positive while it maximizes the
        distance between anchor and negative. It computes the following loss function:

        ``loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0)``.

        The margin is an important hyperparameter and needs to be tuned for the task at hand.

        Args:
            model: SentenceTransformerModel
            distance_metric: Function to compute distance between two
                embeddings. The class TripletDistanceMetric contains
                common distance metrics that can be used.
            triplet_margin: The negative should be at least this much
                further away from the anchor than the positive.

        References:
            - For further details, see: https://en.wikipedia.org/wiki/Triplet_loss

        Requirements:
            1. (anchor, positive, negative) triplets

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive, negative) triplets | none   |
            +---------------------------------------+--------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                    "negative": ["It's quite rainy, sadly.", "She walked to the store."],
                })
                loss = losses.TripletLoss(model=model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]

        rep_anchor, rep_pos, rep_neg = reps
        distance_pos = self.distance_metric(rep_anchor, rep_pos)
        distance_neg = self.distance_metric(rep_anchor, rep_neg)

        losses = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return losses.mean()

    def get_config_dict(self) -> dict[str, Any]:
        distance_metric_name = self.distance_metric.__name__
        for name, value in vars(TripletDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = f"TripletDistanceMetric.{name}"
                break

        return {"distance_metric": distance_metric_name, "triplet_margin": self.triplet_margin}

    @property
    def citation(self) -> str:
        return """
@misc{hermans2017defense,
    title={In Defense of the Triplet Loss for Person Re-Identification},
    author={Alexander Hermans and Lucas Beyer and Bastian Leibe},
    year={2017},
    eprint={1703.07737},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
"""
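

# Minimal sketch (not part of the original module): a self-contained check of the
# loss formula ``max(d(a, p) - d(a, n) + margin, 0)`` on random embeddings,
# bypassing the SentenceTransformer model entirely. The tensor shapes, the margin,
# and the choice of EUCLIDEAN below are assumptions chosen only for illustration.
if __name__ == "__main__":
    import torch

    anchor = torch.randn(4, 8)
    positive = anchor + 0.1 * torch.randn(4, 8)  # embeddings close to the anchor
    negative = torch.randn(4, 8)                 # embeddings unrelated to the anchor

    margin = 5.0
    distance = TripletDistanceMetric.EUCLIDEAN
    loss = F.relu(distance(anchor, positive) - distance(anchor, negative) + margin).mean()
    print(f"Triplet loss on random embeddings: {loss.item():.4f}")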