Source code for sentence_transformers.evaluation.TranslationEvaluator

from __future__ import annotations

import csv
import logging
import os
from contextlib import nullcontext
from typing import TYPE_CHECKING

import numpy as np
import torch

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.util import pytorch_cos_sim

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class TranslationEvaluator(SentenceEvaluator):
    """
    Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3...) and (fr_1, fr_2, fr_3, ...),
    and assuming that fr_i is the translation of en_i, checks if vec(en_i) has the highest similarity
    to vec(fr_i). Computes the accuracy in both directions.

    Example:
        ::

            from sentence_transformers import SentenceTransformer
            from sentence_transformers.evaluation import TranslationEvaluator
            from datasets import load_dataset

            # Load a model
            model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')

            # Load a parallel sentences dataset
            dataset = load_dataset("sentence-transformers/parallel-sentences-news-commentary", "en-nl", split="train[:1000]")

            # Initialize the TranslationEvaluator using the same texts from two languages
            translation_evaluator = TranslationEvaluator(
                source_sentences=dataset["english"],
                target_sentences=dataset["non_english"],
                name="news-commentary-en-nl",
            )
            results = translation_evaluator(model)
            '''
            Evaluating translation matching Accuracy of the model on the news-commentary-en-nl dataset:
            Accuracy src2trg: 90.80
            Accuracy trg2src: 90.40
            '''
            print(translation_evaluator.primary_metric)
            # => "news-commentary-en-nl_mean_accuracy"
            print(results[translation_evaluator.primary_metric])
            # => 0.906
    """

    def __init__(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        show_progress_bar: bool = False,
        batch_size: int = 16,
        name: str = "",
        print_wrong_matches: bool = False,
        write_csv: bool = True,
        truncate_dim: int | None = None,
    ):
        """
        Constructs an evaluator for the given parallel dataset.

        Args:
            source_sentences (List[str]): List of sentences in the source language.
            target_sentences (List[str]): List of sentences in the target language.
            show_progress_bar (bool): Whether to show a progress bar when computing embeddings. Defaults to False.
            batch_size (int): The batch size to compute sentence embeddings. Defaults to 16.
            name (str): The name of the evaluator. Defaults to an empty string.
            print_wrong_matches (bool): Whether to print incorrect matches. Defaults to False.
            write_csv (bool): Whether to write the evaluation results to a CSV file. Defaults to True.
            truncate_dim (int, optional): The dimension to truncate sentence embeddings to. If None, the model's
                current truncation dimension will be used. Defaults to None.
        """
        super().__init__()
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.name = name
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.print_wrong_matches = print_wrong_matches
        self.truncate_dim = truncate_dim

        assert len(self.source_sentences) == len(self.target_sentences)

        if name:
            name = "_" + name

        self.csv_file = "translation_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "src2trg", "trg2src"]
        self.write_csv = write_csv
        self.primary_metric = "mean_accuracy"

    def __call__(
        self, model: SentenceTransformer, output_path: str | None = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, float]:
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(f"Evaluating translation matching Accuracy of the model on the {self.name} dataset{out_txt}:")

        # Encode both sides as tensors (convert_to_numpy=False) so they can be stacked directly
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            embeddings1 = torch.stack(
                model.encode(
                    self.source_sentences,
                    show_progress_bar=self.show_progress_bar,
                    batch_size=self.batch_size,
                    convert_to_numpy=False,
                )
            )
            embeddings2 = torch.stack(
                model.encode(
                    self.target_sentences,
                    show_progress_bar=self.show_progress_bar,
                    batch_size=self.batch_size,
                    convert_to_numpy=False,
                )
            )

        # Full pairwise similarity matrix: entry [i][j] = cos_sim(source_i, target_j)
        cos_sims = pytorch_cos_sim(embeddings1, embeddings2).detach().cpu().numpy()

        correct_src2trg = 0
        correct_trg2src = 0

        # src2trg: source i is matched correctly if target i is its nearest neighbor
        for i in range(len(cos_sims)):
            max_idx = np.argmax(cos_sims[i])

            if i == max_idx:
                correct_src2trg += 1
            elif self.print_wrong_matches:
                print("\nIncorrect: Source", i, "is most similar to target", max_idx, "instead of target", i)
                print("Source:", self.source_sentences[i])
                print("Pred Target:", self.target_sentences[max_idx], f"(Score: {cos_sims[i][max_idx]:.4f})")
                print("True Target:", self.target_sentences[i], f"(Score: {cos_sims[i][i]:.4f})")

                results = enumerate(cos_sims[i])
                results = sorted(results, key=lambda x: x[1], reverse=True)
                for idx, score in results[:5]:
                    print("\t", idx, f"(Score: {score:.4f})", self.target_sentences[idx])

        # trg2src: transpose the matrix and repeat the check in the other direction
        cos_sims = cos_sims.T
        for i in range(len(cos_sims)):
            max_idx = np.argmax(cos_sims[i])
            if i == max_idx:
                correct_trg2src += 1

        acc_src2trg = correct_src2trg / len(cos_sims)
        acc_trg2src = correct_trg2src / len(cos_sims)

        logger.info(f"Accuracy src2trg: {acc_src2trg * 100:.2f}")
        logger.info(f"Accuracy trg2src: {acc_trg2src * 100:.2f}")

        if output_path is not None and self.write_csv:
            # Append to an existing CSV, otherwise create it with a header row
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc_src2trg, acc_trg2src])

        metrics = {
            "src2trg_accuracy": acc_src2trg,
            "trg2src_accuracy": acc_trg2src,
            "mean_accuracy": (acc_src2trg + acc_trg2src) / 2,
        }
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics