Source code for sentence_transformers.evaluation.ParaphraseMiningEvaluator

from __future__ import annotations

import csv
import logging
import os
from collections import defaultdict
from contextlib import nullcontext
from typing import TYPE_CHECKING

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.util import paraphrase_mining

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


[docs]
class ParaphraseMiningEvaluator(SentenceEvaluator):
    """
    Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and
    identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs
    with a set of gold labels and computes the F1 score.

    Example:
        ::

            from datasets import load_dataset
            from sentence_transformers.SentenceTransformer import SentenceTransformer
            from sentence_transformers.evaluation import ParaphraseMiningEvaluator

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load the Quora Duplicates Mining dataset
            questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
            duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")

            # Create a mapping from qid to question & a list of duplicates (qid1, qid2)
            qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
            duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

            # Initialize the paraphrase mining evaluator
            paraphrase_mining_evaluator = ParaphraseMiningEvaluator(
                sentences_map=qid_to_questions,
                duplicates_list=duplicates,
                name="quora-duplicates-dev",
            )
            results = paraphrase_mining_evaluator(model)
            '''
            Paraphrase Mining Evaluation of the model on the quora-duplicates-dev dataset:
            Number of candidate pairs: 250564
            Average Precision: 56.51
            Optimal threshold: 0.8325
            Precision: 52.76
            Recall: 59.19
            F1: 55.79
            '''
            print(paraphrase_mining_evaluator.primary_metric)
            # => "quora-duplicates-dev_average_precision"
            print(results[paraphrase_mining_evaluator.primary_metric])
            # => 0.5650940787776353
    """

    def __init__(
        self,
        sentences_map: dict[str, str],
        duplicates_list: list[tuple[str, str]] = None,
        duplicates_dict: dict[str, dict[str, bool]] = None,
        add_transitive_closure: bool = False,
        query_chunk_size: int = 5000,
        corpus_chunk_size: int = 100000,
        max_pairs: int = 500000,
        top_k: int = 100,
        show_progress_bar: bool = False,
        batch_size: int = 16,
        name: str = "",
        write_csv: bool = True,
        truncate_dim: int | None = None,
    ):
        """
        Initializes the ParaphraseMiningEvaluator.

        Args:
            sentences_map (Dict[str, str]): A dictionary that maps sentence-ids to sentences.
                For example, sentences_map[id] => sentence.
            duplicates_list (List[Tuple[str, str]], optional): A list with id pairs [(id1, id2), (id1, id5)]
                that identifies the duplicates / paraphrases in the sentences_map. Defaults to None.
            duplicates_dict (Dict[str, Dict[str, bool]], optional): A default dictionary mapping [id1][id2]
                to true if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True,
                then [id2][id1] => True. Defaults to None.
            add_transitive_closure (bool, optional): If true, it adds a transitive closure,
                i.e. if dup[a][b] and dup[b][c], then dup[a][c]. Defaults to False.
            query_chunk_size (int, optional): To identify the paraphrases, the cosine-similarity between
                all sentence-pairs will be computed. As this might require a lot of memory, we perform
                a batched computation. query_chunk_size sentences will be compared against up to
                corpus_chunk_size sentences. In the default setting, 5000 sentences will be grouped
                together and compared against up to 100k other sentences. Defaults to 5000.
            corpus_chunk_size (int, optional): The corpus will be batched to reduce the memory
                requirement. Defaults to 100000.
            max_pairs (int, optional): We will only extract up to max_pairs potential paraphrase
                candidates. Defaults to 500000.
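            top_k (int, optional): For each query, we extract the top_k most similar pairs and add
                them to a sorted list. I.e., for one sentence we cannot find more than top_k
                paraphrases. Defaults to 100.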
            show_progress_bar (bool, optional): Output a progress bar. Defaults to False.
            batch_size (int, optional): Batch size for computing sentence embeddings. Defaults to 16.
            name (str, optional): Name of the experiment. Defaults to "".
            write_csv (bool, optional): Write results to a CSV file. Defaults to True.
            truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to.
                `None` uses the model's current truncation dimension. Defaults to None.
        """
        super().__init__()

        # Illustrative note (ids "q1"/"q2" below are hypothetical): a valid
        # ``duplicates_dict`` is a symmetric nested mapping, e.g.
        #     dup = defaultdict(lambda: defaultdict(bool))
        #     dup["q1"]["q2"] = dup["q2"]["q1"] = True
        #     ParaphraseMiningEvaluator(sentences_map, duplicates_dict=dup)
        self.sentences = []
        self.ids = []

        for id, sentence in sentences_map.items():
            self.sentences.append(sentence)
            self.ids.append(id)

        self.name = name
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.query_chunk_size = query_chunk_size
        self.corpus_chunk_size = corpus_chunk_size
        self.max_pairs = max_pairs
        self.top_k = top_k
        self.truncate_dim = truncate_dim

        self.duplicates = duplicates_dict if duplicates_dict is not None else defaultdict(lambda: defaultdict(bool))
        if duplicates_list is not None:
            for id1, id2 in duplicates_list:
                if id1 in sentences_map and id2 in sentences_map:
                    self.duplicates[id1][id2] = True
                    self.duplicates[id2][id1] = True

        # Add transitive closure
        if add_transitive_closure:
            self.duplicates = self.add_transitive_closure(self.duplicates)

        positive_key_pairs = set()
        for key1 in self.duplicates:
            for key2 in self.duplicates[key1]:
                if (
                    key1 in sentences_map
                    and key2 in sentences_map
                    and (self.duplicates[key1][key2] or self.duplicates[key2][key1])
                ):
                    positive_key_pairs.add(tuple(sorted([key1, key2])))

        self.total_num_duplicates = len(positive_key_pairs)

        if name:
            name = "_" + name

        self.csv_file: str = "paraphrase_mining_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
        self.write_csv = write_csv
        self.primary_metric = "average_precision"

    def __call__(
        self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
    ) -> dict[str, float]:
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(f"Paraphrase Mining Evaluation of the model on the {self.name} dataset{out_txt}:")

        # Compute embedding for the sentences
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            pairs_list = paraphrase_mining(
                model,
                self.sentences,
                self.show_progress_bar,
                self.batch_size,
                self.query_chunk_size,
                self.corpus_chunk_size,
                self.max_pairs,
                self.top_k,
            )

        logger.info("Number of candidate pairs: " + str(len(pairs_list)))

        # Compute F1 score and Average Precision
        n_extract = n_correct = 0
        threshold = 0
        best_f1 = best_recall = best_precision = 0
        average_precision = 0

        for idx in range(len(pairs_list)):
            score, i, j = pairs_list[idx]
            id1 = self.ids[i]
            id2 = self.ids[j]

            # Compute optimal threshold and F1-score
            n_extract += 1
            if self.duplicates[id1][id2] or self.duplicates[id2][id1]:
                n_correct += 1
                precision = n_correct / n_extract
                recall = n_correct / self.total_num_duplicates
                f1 = 2 * precision * recall / (precision + recall)
                average_precision += precision
                if f1 > best_f1:
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
                    threshold = (pairs_list[idx][0] + pairs_list[min(idx + 1, len(pairs_list) - 1)][0]) / 2
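        # Note on the sweep above: ``paraphrase_mining`` returns ``pairs_list``
        # sorted by decreasing similarity, so iterating top-to-bottom evaluates
        # precision/recall at every candidate threshold; average precision is
        # the mean of the precision values at each true-positive hit.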
        average_precision = average_precision / self.total_num_duplicates

        logger.info(f"Average Precision: {average_precision * 100:.2f}")
        logger.info(f"Optimal threshold: {threshold:.4f}")
        logger.info(f"Precision: {best_precision * 100:.2f}")
        logger.info(f"Recall: {best_recall * 100:.2f}")
        logger.info(f"F1: {best_f1 * 100:.2f}\n")

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                with open(csv_path, newline="", mode="w", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(self.csv_headers)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])
            else:
                with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])

        metrics = {
            "average_precision": average_precision,
            "f1": best_f1,
            "precision": best_precision,
            "recall": best_recall,
            "threshold": threshold,
        }
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    @staticmethod
    def add_transitive_closure(graph):
        nodes_visited = set()
        for a in list(graph.keys()):
            if a not in nodes_visited:
                connected_subgraph_nodes = set()
                connected_subgraph_nodes.add(a)

                # Add all nodes in the connected graph
                neighbor_nodes_queue = list(graph[a])
                while len(neighbor_nodes_queue) > 0:
                    node = neighbor_nodes_queue.pop(0)
                    if node not in connected_subgraph_nodes:
                        connected_subgraph_nodes.add(node)
                        neighbor_nodes_queue.extend(graph[node])

                # Ensure transitivity between all nodes in the graph
                connected_subgraph_nodes = list(connected_subgraph_nodes)
                for i in range(len(connected_subgraph_nodes) - 1):
                    for j in range(i + 1, len(connected_subgraph_nodes)):
                        graph[connected_subgraph_nodes[i]][connected_subgraph_nodes[j]] = True
                        graph[connected_subgraph_nodes[j]][connected_subgraph_nodes[i]] = True

                        nodes_visited.add(connected_subgraph_nodes[i])
                        nodes_visited.add(connected_subgraph_nodes[j])

        return graph
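

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library): demonstrates
# that ``add_transitive_closure`` marks every pair inside a connected duplicate
# component as duplicates. The ids "a"/"b"/"c" are hypothetical.
if __name__ == "__main__":
    example = defaultdict(lambda: defaultdict(bool))
    example["a"]["b"] = example["b"]["a"] = True
    example["b"]["c"] = example["c"]["b"] = True
    closed = ParaphraseMiningEvaluator.add_transitive_closure(example)
    assert closed["a"]["c"] and closed["c"]["a"]  # closure added a~c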