# Source code for sentence_transformers.cross_encoder.evaluation.CESoftmaxAccuracyEvaluator

from __future__ import annotations

import csv
import logging
import os

import numpy as np

from sentence_transformers import InputExample

logger = logging.getLogger(__name__)


class CESoftmaxAccuracyEvaluator:
    """
    Accuracy evaluator for the CrossEncoder class.

    It is designed for CrossEncoders with 2 or more outputs. The predicted
    class of every sentence pair is the argmax over the scores returned by
    ``model.predict``; the evaluator measures how often that predicted class
    equals the gold label.

    Args:
        sentence_pairs: The sentence pairs that are passed to ``model.predict``.
        labels: The gold class index of each sentence pair, in the same order.
        name: Optional dataset name. It appears in the log message and, when
            non-empty, in the name of the CSV results file.
        write_csv: Whether ``__call__`` appends its result to a CSV file when
            it is given an ``output_path``.
    """

    def __init__(self, sentence_pairs: list[list[str]], labels: list[int], name: str = "", write_csv: bool = True):
        self.sentence_pairs = sentence_pairs
        self.labels = labels
        self.name = name

        self.csv_file = "CESoftmaxAccuracyEvaluator" + ("_" + name if name else "") + "_results.csv"
        self.csv_headers = ["epoch", "steps", "Accuracy"]
        self.write_csv = write_csv

    @classmethod
    def from_input_examples(cls, examples: list[InputExample], **kwargs):
        """
        Build an evaluator from objects exposing ``texts`` and ``label`` attributes.

        Args:
            examples: The examples to evaluate on; ``example.texts`` becomes a
                sentence pair and ``example.label`` its gold label.
            **kwargs: Forwarded unchanged to the constructor (e.g. ``name``).

        Returns:
            A new evaluator instance.
        """
        sentence_pairs = [example.texts for example in examples]
        labels = [example.label for example in examples]
        return cls(sentence_pairs, labels, **kwargs)

    def __call__(self, model, output_path: str | None = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Evaluate ``model`` and return its accuracy on the stored sentence pairs.

        Args:
            model: Object whose ``predict(sentence_pairs, convert_to_numpy=True,
                show_progress_bar=False)`` returns one row of class scores per pair.
            output_path: Directory for the CSV results file. It is created if it
                does not exist. Nothing is written when this is ``None`` or when
                ``write_csv`` is ``False``.
            epoch: Epoch number used in the log message and CSV row; ``-1`` means
                the evaluation is not tied to an epoch.
            steps: Step number within the epoch; ``-1`` means the evaluation runs
                at the end of the epoch.

        Returns:
            The fraction of sentence pairs whose predicted class equals the gold label.

        Raises:
            AssertionError: If the number of predictions differs from the number of labels.
        """
        if epoch == -1:
            out_txt = ":"
        elif steps == -1:
            out_txt = f" after epoch {epoch}:"
        else:
            out_txt = f" in epoch {epoch} after {steps} steps:"

        logger.info("CESoftmaxAccuracyEvaluator: Evaluating the model on %s dataset%s", self.name, out_txt)

        pred_scores = model.predict(self.sentence_pairs, convert_to_numpy=True, show_progress_bar=False)
        pred_labels = np.argmax(pred_scores, axis=1)

        # Raised explicitly (instead of a bare ``assert``) so the check survives
        # ``python -O``; the exception type is kept for backward compatibility.
        if len(pred_labels) != len(self.labels):
            raise AssertionError(
                f"Number of predictions ({len(pred_labels)}) does not match number of labels ({len(self.labels)})"
            )

        # Convert the gold labels once so the comparison is an element-wise array operation.
        gold_labels = np.asarray(self.labels)
        acc = np.sum(pred_labels == gold_labels) / len(self.labels)

        logger.info(f"Accuracy: {acc * 100:.2f}")

        if output_path is not None and self.write_csv:
            # The first evaluation may target a directory that does not exist yet.
            os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" lets the csv module control line endings itself, which
            # avoids blank lines between rows on platforms that translate "\n".
            with open(csv_path, mode="a" if output_file_exists else "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc])

        return acc