"""一个在完成的运行上运行评估器的跟踪器。"""
from __future__ import annotations
import logging
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
import langsmith
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
from langchain_core.tracers import langchain as langchain_tracer
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import tracing_v2_enabled
from langchain_core.tracers.langchain import _get_executor
from langchain_core.tracers.schemas import Run
# Module-level logger; used below to report evaluator failures and skipped runs.
logger = logging.getLogger(__name__)
# Weak registry of every EvaluatorCallbackHandler that has been constructed.
# Handlers add themselves in __init__, and wait_for_all_evaluators() iterates
# this set. Being a WeakSet, the registry itself does not keep handlers alive.
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
def wait_for_all_evaluators() -> None:
    """Block until every registered evaluator tracer has drained its futures."""
    # Take a snapshot first: _TRACERS is a WeakSet and may shrink while we are
    # blocked inside wait_for_futures().
    registered = list(_TRACERS)
    live_handlers = [handler for handler in registered if handler is not None]
    for handler in live_handlers:
        handler.wait_for_futures()
class EvaluatorCallbackHandler(BaseTracer):
    """Tracer that runs a set of run evaluators whenever a run is persisted.

    Parameters
    ----------
    evaluators : Sequence[RunEvaluator]
        The run evaluators to apply to every run handed to ``_persist_run``.
    client : LangSmith Client, optional
        The LangSmith client used to evaluate runs and record feedback.
        If not given, the client returned by ``langchain_tracer.get_client()``
        is used.
    example_id : Union[UUID, str], optional
        The example ID to associate with the runs. Strings are parsed to UUID.
    skip_unfinished : bool, optional
        If True (default), runs without outputs are not evaluated.
    project_name : str, optional
        The LangSmith project name under which the evaluator runs are traced.
        Defaults to ``"evaluators"``.
    max_concurrency : int, optional
        ``None`` (default) uses the executor returned by ``_get_executor()``;
        a positive value creates a dedicated thread pool with that many
        workers; zero or a negative value runs evaluators synchronously.

    Attributes
    ----------
    example_id : Union[UUID, None]
        The example ID associated with the runs.
    client : Client
        The LangSmith client instance used for evaluating the runs.
    evaluators : Sequence[RunEvaluator]
        The sequence of run evaluators to be executed.
    executor : Optional[ThreadPoolExecutor]
        The thread pool used to run the evaluators, or None when evaluators
        run synchronously.
    futures : weakref.WeakSet[Future]
        Weak set of futures for evaluator calls submitted to the executor.
    skip_unfinished : bool
        Whether to skip runs that have no outputs.
    project_name : Optional[str]
        The LangSmith project name used to organize the evaluator runs.
    logged_eval_results : Dict[Tuple[str, str], List[EvaluationResult]]
        Evaluation results collected so far, keyed by
        ``(run id, reference example id)`` as strings.
    lock : threading.Lock
        Guards ``logged_eval_results`` against concurrent worker updates.
    """

    name = "evaluator_callback_handler"

    def __init__(
        self,
        evaluators: Sequence[langsmith.RunEvaluator],
        client: Optional[langsmith.Client] = None,
        example_id: Optional[Union[UUID, str]] = None,
        skip_unfinished: bool = True,
        project_name: Optional[str] = "evaluators",
        max_concurrency: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Set up the handler and register it in the module-level registry.

        See the class docstring for the meaning of each parameter. Extra
        keyword arguments are forwarded to ``BaseTracer.__init__``.
        """
        super().__init__(**kwargs)
        self.example_id = (
            UUID(example_id) if isinstance(example_id, str) else example_id
        )
        self.client = client or langchain_tracer.get_client()
        self.evaluators = evaluators
        if max_concurrency is None:
            self.executor: Optional[ThreadPoolExecutor] = _get_executor()
        elif max_concurrency > 0:
            owned_executor = ThreadPoolExecutor(max_workers=max_concurrency)
            self.executor = owned_executor
            # The finalizer callback must not reference ``self``: a closure
            # over ``self`` would be held strongly by the finalizer and keep
            # this handler alive forever. The executor's own bound method
            # carries no such reference.
            weakref.finalize(self, owned_executor.shutdown, wait=True)
        else:
            # Zero or negative concurrency: evaluate inline in _persist_run.
            self.executor = None
        self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
        self.skip_unfinished = skip_unfinished
        self.project_name = project_name
        self.logged_eval_results: Dict[Tuple[str, str], List[EvaluationResult]] = {}
        self.lock = threading.Lock()
        # Register so wait_for_all_evaluators() can find this handler.
        _TRACERS.add(self)

    def _evaluate_in_project(
        self, run: Run, evaluator: langsmith.RunEvaluator
    ) -> None:
        """Evaluate a single run with a single evaluator and record the results.

        Parameters
        ----------
        run : Run
            The run to be evaluated.
        evaluator : RunEvaluator
            The evaluator to apply to the run.

        Raises
        ------
        Exception
            Any exception raised while evaluating or logging feedback is
            logged and then re-raised unchanged.
        """
        try:
            if self.project_name is None:
                # No evaluation project configured: delegate to the client.
                # NOTE(review): presumably client.evaluate_run records its own
                # feedback - confirm against the langsmith client.
                eval_results = [self.client.evaluate_run(run, evaluator)]
            else:
                # Trace the evaluator itself under the evaluation project.
                # This is an ``else`` so the evaluator runs exactly once.
                with tracing_v2_enabled(
                    project_name=self.project_name,
                    tags=["eval"],
                    client=self.client,
                ) as cb:
                    reference_example = (
                        self.client.read_example(run.reference_example_id)
                        if run.reference_example_id
                        else None
                    )
                    evaluation_result = evaluator.evaluate_run(
                        # This is subclass, but getting errors for some reason
                        run,  # type: ignore
                        example=reference_example,
                    )
                    eval_results = self._log_evaluation_feedback(
                        evaluation_result,
                        run,
                        source_run_id=cb.latest_run.id if cb.latest_run else None,
                    )
        except Exception as e:
            logger.error(
                f"Error evaluating run {run.id} with "
                f"{evaluator.__class__.__name__}: {repr(e)}",
                exc_info=True,
            )
            # Bare raise keeps the original traceback intact.
            raise
        example_id = str(run.reference_example_id)
        with self.lock:
            for res in eval_results:
                # Fall back to the evaluated run's id when the result names no
                # explicit target, matching _log_evaluation_feedback.
                target_run_id = getattr(res, "target_run_id", None)
                run_id = str(target_run_id if target_run_id is not None else run.id)
                self.logged_eval_results.setdefault((run_id, example_id), []).append(
                    res
                )

    def _select_eval_results(
        self,
        results: Union[EvaluationResult, EvaluationResults],
    ) -> List[EvaluationResult]:
        """Normalize an evaluator response into a flat list of results.

        Parameters
        ----------
        results : Union[EvaluationResult, EvaluationResults]
            Either a single result or a dict carrying a ``"results"`` list.

        Returns
        -------
        List[EvaluationResult]
            The single result wrapped in a list, or the ``"results"`` entry.

        Raises
        ------
        TypeError
            If ``results`` is neither of the two accepted shapes.
        """
        if isinstance(results, EvaluationResult):
            return [results]
        if isinstance(results, dict) and "results" in results:
            return cast(List[EvaluationResult], results["results"])
        raise TypeError(
            f"Invalid evaluation result type {type(results)}."
            " Expected EvaluationResult or EvaluationResults."
        )

    def _log_evaluation_feedback(
        self,
        evaluator_response: Union[EvaluationResult, EvaluationResults],
        run: Run,
        source_run_id: Optional[UUID] = None,
    ) -> List[EvaluationResult]:
        """Send each evaluation result to the client as model feedback.

        Parameters
        ----------
        evaluator_response : Union[EvaluationResult, EvaluationResults]
            The raw response returned by the evaluator.
        run : Run
            The evaluated run; its id is the feedback target unless a result
            names its own ``target_run_id``.
        source_run_id : UUID, optional
            Fallback source run id for results that do not carry one.

        Returns
        -------
        List[EvaluationResult]
            The normalized results that were logged.
        """
        results = self._select_eval_results(evaluator_response)
        for res in results:
            source_info_: Dict[str, Any] = (
                dict(res.evaluator_info) if res.evaluator_info else {}
            )
            target_run_id = getattr(res, "target_run_id", None)
            run_id_ = target_run_id if target_run_id is not None else run.id
            self.client.create_feedback(
                run_id_,
                res.key,
                score=res.score,
                value=res.value,
                comment=res.comment,
                correction=res.correction,
                source_info=source_info_,
                source_run_id=res.source_run_id or source_run_id,
                feedback_source_type=langsmith.schemas.FeedbackSourceType.MODEL,
            )
        return results

    def _persist_run(self, run: Run) -> None:
        """Run every configured evaluator on the given run.

        Parameters
        ----------
        run : Run
            The run to be evaluated.
        """
        if self.skip_unfinished and not run.outputs:
            logger.debug(f"Skipping unfinished run {run.id}")
            return
        # Work on a copy so the caller's run is not mutated when the reference
        # example id is attached.
        run_ = run.copy()
        run_.reference_example_id = self.example_id
        for evaluator in self.evaluators:
            if self.executor is None:
                # Synchronous mode: evaluate inline; exceptions propagate.
                self._evaluate_in_project(run_, evaluator)
            else:
                self.futures.add(
                    self.executor.submit(self._evaluate_in_project, run_, evaluator)
                )

    def wait_for_futures(self) -> None:
        """Block until all evaluator futures tracked by this handler are done."""
        wait(self.futures)