# Source code for sentence_transformers.cross_encoder.evaluation.CEBinaryClassificationEvaluator

from __future__ import annotations

import csv
import logging
import os

import numpy as np
from sklearn.metrics import average_precision_score

from sentence_transformers import InputExample
from sentence_transformers.evaluation import BinaryClassificationEvaluator

logger = logging.getLogger(__name__)


class CEBinaryClassificationEvaluator:
    """
    Evaluator for binary classification with the CrossEncoder class.

    Given sentence pairs and binary labels (0 and 1), it computes the average
    precision as well as the best accuracy and F1 score (with precision,
    recall and the corresponding score thresholds) found by
    ``BinaryClassificationEvaluator``.

    Args:
        sentence_pairs: Sentence pairs that are passed to ``model.predict``.
        labels: One binary label (0 or 1) per sentence pair.
        name: Optional dataset name; used in the log message and in the CSV
            file name.
        show_progress_bar: Forwarded to ``model.predict``. If ``None``, the
            progress bar is enabled when the module logger's effective level
            is INFO or DEBUG.
        write_csv: If True, ``__call__`` appends the results to a CSV file
            inside ``output_path`` (when an ``output_path`` is given).
    """

    def __init__(
        self,
        sentence_pairs: list[list[str]],
        labels: list[int],
        name: str = "",
        show_progress_bar: bool | None = False,
        write_csv: bool = True,
    ):
        # Kept as assertions (not ValueError) so the exception type callers
        # may already rely on does not change.
        assert len(sentence_pairs) == len(labels), "sentence_pairs and labels must have the same length"
        assert all(label == 0 or label == 1 for label in labels), "labels must be 0 or 1"

        self.sentence_pairs = sentence_pairs
        self.labels = np.asarray(labels)
        self.name = name

        if show_progress_bar is None:
            # Derive the default from the logging verbosity.
            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
        self.show_progress_bar = show_progress_bar

        self.csv_file = "CEBinaryClassificationEvaluator" + ("_" + name if name else "") + "_results.csv"
        self.csv_headers = [
            "epoch",
            "steps",
            "Accuracy",
            "Accuracy_Threshold",
            "F1",
            "F1_Threshold",
            "Precision",
            "Recall",
            "Average_Precision",
        ]
        self.write_csv = write_csv

    @classmethod
    def from_input_examples(cls, examples: list[InputExample], **kwargs):
        """
        Build an evaluator from examples exposing ``texts`` and ``label``.

        Args:
            examples: Objects whose ``texts`` attribute is the sentence pair
                and whose ``label`` attribute is the binary label.
            **kwargs: Forwarded unchanged to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        sentence_pairs = [example.texts for example in examples]
        labels = [example.label for example in examples]
        return cls(sentence_pairs, labels, **kwargs)

    def __call__(self, model, output_path: str | None = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Score the sentence pairs with ``model`` and report the metrics.

        Args:
            model: Object providing ``predict(sentence_pairs,
                convert_to_numpy=..., show_progress_bar=...)``.
            output_path: Directory in which the CSV results file is written.
                Nothing is written when this is ``None`` or when
                ``write_csv`` is False.
            epoch: Epoch number used for logging and the CSV row; -1 means
                "not specified".
            steps: Step number used for logging and the CSV row; -1 means
                "not specified".

        Returns:
            The average precision of the predicted scores.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}:"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps:"
        else:
            out_txt = ":"

        logger.info("CEBinaryClassificationEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
        pred_scores = model.predict(
            self.sentence_pairs, convert_to_numpy=True, show_progress_bar=self.show_progress_bar
        )

        # NOTE(review): the third positional argument (True) presumably means
        # "a higher score indicates the positive class" - confirm against
        # BinaryClassificationEvaluator.
        acc, acc_threshold = BinaryClassificationEvaluator.find_best_acc_and_threshold(pred_scores, self.labels, True)
        f1, precision, recall, f1_threshold = BinaryClassificationEvaluator.find_best_f1_and_threshold(
            pred_scores, self.labels, True
        )
        ap = average_precision_score(self.labels, pred_scores)

        logger.info(f"Accuracy:           {acc * 100:.2f}\t(Threshold: {acc_threshold:.4f})")
        logger.info(f"F1:                 {f1 * 100:.2f}\t(Threshold: {f1_threshold:.4f})")
        logger.info(f"Precision:          {precision * 100:.2f}")
        logger.info(f"Recall:             {recall * 100:.2f}")
        logger.info(f"Average Precision:  {ap * 100:.2f}\n")

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" is required by the csv module; without it the writer's
            # "\r\n" terminator is translated again on Windows, producing
            # blank rows between records.
            with open(csv_path, mode="a" if output_file_exists else "w", encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    # Header only once, when the file is first created.
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc, acc_threshold, f1, f1_threshold, precision, recall, ap])

        return ap