Source code for sentence_transformers.losses.SoftmaxLoss

from __future__ import annotations

import logging
from typing import Callable, Iterable

import torch
import transformers
from packaging import version
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class SoftmaxLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        sentence_embedding_dimension: int,
        num_labels: int,
        concatenation_sent_rep: bool = True,
        concatenation_sent_difference: bool = True,
        concatenation_sent_multiplication: bool = False,
        loss_fct: Callable = nn.CrossEntropyLoss(),
    ) -> None:
        """
        This loss was used in our SBERT publication (https://arxiv.org/abs/1908.10084) to train the
        SentenceTransformer model on NLI data. It adds a softmax classifier on top of the output of
        two transformer networks.

        :class:`MultipleNegativesRankingLoss` is an alternative loss function that often yields
        better results, as per https://arxiv.org/abs/2004.09813.

        Args:
            model (SentenceTransformer): The SentenceTransformer model.
            sentence_embedding_dimension (int): The dimension of the sentence embeddings.
            num_labels (int): The number of different labels.
            concatenation_sent_rep (bool): Whether to concatenate vectors u,v for the softmax classifier. Defaults to True.
            concatenation_sent_difference (bool): Whether to add abs(u-v) for the softmax classifier. Defaults to True.
            concatenation_sent_multiplication (bool): Whether to add u*v for the softmax classifier. Defaults to False.
            loss_fct (Callable): Custom pytorch loss function. If not set, uses nn.CrossEntropyLoss(). Defaults to nn.CrossEntropyLoss().

        References:
            - Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks: https://arxiv.org/abs/1908.10084
            - `Training Examples > Natural Language Inference <../../examples/training/nli/README.html>`_

        Requirements:
            1. Sentence pairs with a class label

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (sentence_A, sentence_B) pairs        | class  |
            +---------------------------------------+--------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": [
                        "A person on a horse jumps over a broken down airplane.",
                        "A person on a horse jumps over a broken down airplane.",
                        "A person on a horse jumps over a broken down airplane.",
                        "Children smiling and waving at camera",
                    ],
                    "sentence2": [
                        "A person is training his horse for a competition.",
                        "A person is at a diner, ordering an omelette.",
                        "A person is outdoors, on a horse.",
                        "There are children present.",
                    ],
                    "label": [1, 2, 0, 0],
                })
                loss = losses.SoftmaxLoss(model, model.get_sentence_embedding_dimension(), num_labels=3)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.num_labels = num_labels
        self.concatenation_sent_rep = concatenation_sent_rep
        self.concatenation_sent_difference = concatenation_sent_difference
        self.concatenation_sent_multiplication = concatenation_sent_multiplication

        num_vectors_concatenated = 0
        if concatenation_sent_rep:
            num_vectors_concatenated += 2
        if concatenation_sent_difference:
            num_vectors_concatenated += 1
        if concatenation_sent_multiplication:
            num_vectors_concatenated += 1
        logger.info(f"Softmax loss: #Vectors concatenated: {num_vectors_concatenated}")
        self.classifier = nn.Linear(
            num_vectors_concatenated * sentence_embedding_dimension, num_labels, device=model.device
        )
        self.loss_fct = loss_fct

        if version.parse(transformers.__version__) < version.parse("4.43.0"):
            logger.warning(
                "SoftmaxLoss requires transformers >= 4.43.0 to work correctly. "
                "Otherwise, the classifier layer that maps embeddings to the labels cannot be updated. "
                "Consider updating transformers with `pip install transformers>=4.43.0`."
            )

    def forward(
        self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
    ) -> Tensor | tuple[Tensor, Tensor]:
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        rep_a, rep_b = reps

        vectors_concat = []
        if self.concatenation_sent_rep:
            vectors_concat.append(rep_a)
            vectors_concat.append(rep_b)

        if self.concatenation_sent_difference:
            vectors_concat.append(torch.abs(rep_a - rep_b))

        if self.concatenation_sent_multiplication:
            vectors_concat.append(rep_a * rep_b)

        features = torch.cat(vectors_concat, 1)

        output = self.classifier(features)

        if labels is not None:
            loss = self.loss_fct(output, labels.view(-1))
            return loss
        else:
            return reps, output

    @property
    def citation(self) -> str:
        return """
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
"""