# Source code for langchain_community.callbacks.tracers.comet

from types import ModuleType, SimpleNamespace
from typing import TYPE_CHECKING, Any, Callable, Dict

from langchain_core.tracers import BaseTracer
from langchain_core.utils import guard_import

if TYPE_CHECKING:
    from uuid import UUID

    from comet_llm import Span
    from comet_llm.chains.chain import Chain

    from langchain_community.callbacks.tracers.schemas import Run


def _get_run_type(run: "Run") -> str:
    if isinstance(run.run_type, str):
        return run.run_type
    elif hasattr(run.run_type, "value"):
        return run.run_type.value
    else:
        return str(run.run_type)


def import_comet_llm_api() -> SimpleNamespace:
    """Import the comet_llm API and raise an error if it is not installed.

    Returns:
        A namespace exposing the pieces of ``comet_llm`` used by the tracer:
        ``chain``, ``span`` and ``chain_api`` (taken from
        ``comet_llm.chains``), plus ``experiment_info`` and ``flush``
        (taken from ``comet_llm``).
    """
    llm_pkg = guard_import("comet_llm")
    chains_pkg = guard_import("comet_llm.chains")

    api = SimpleNamespace(
        chain=chains_pkg.chain,
        span=chains_pkg.span,
        chain_api=chains_pkg.api,
        experiment_info=llm_pkg.experiment_info,
        flush=llm_pkg.flush,
    )
    return api


class CometTracer(BaseTracer):
    """Comet tracer.

    A run without a parent is mapped to a ``comet_llm`` chain; every nested
    run is mapped to a span that is started on the chain of its parent run.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Comet tracer.

        Args:
            **kwargs: Forwarded unchanged to ``BaseTracer.__init__``.
        """
        super().__init__(**kwargs)
        # Map from run id to the span created for that (nested) run.
        self._span_map: Dict["UUID", "Span"] = {}
        # Map from run id to the chain that run belongs to.
        self._chains_map: Dict["UUID", "Chain"] = {}
        self._initialize_comet_modules()

    def _initialize_comet_modules(self) -> None:
        """Bind the comet_llm entry points used by this tracer to the instance."""
        comet_llm_api = import_comet_llm_api()
        self._chain: ModuleType = comet_llm_api.chain
        self._span: ModuleType = comet_llm_api.span
        self._chain_api: ModuleType = comet_llm_api.chain_api
        self._experiment_info: ModuleType = comet_llm_api.experiment_info
        self._flush: Callable[[], None] = comet_llm_api.flush

    def _persist_run(self, run: "Run") -> None:
        """Set the run's outputs on its chain and log the chain.

        Raises:
            KeyError: If no chain is registered for ``run.id``.
        """
        run_dict: Dict[str, Any] = run.dict()
        # Pop rather than index: once the chain is logged nothing looks it up
        # again, and keeping it would make the map grow for the whole
        # lifetime of the tracer.
        chain_ = self._chains_map.pop(run.id)
        chain_.set_outputs(outputs=run_dict["outputs"])
        self._chain_api.log_chain(chain_)

    def _process_start_trace(self, run: "Run") -> None:
        """Create the chain (root run) or the span (nested run) for ``run``."""
        run_dict: Dict[str, Any] = run.dict()
        if not run.parent_run_id:
            # This is the first run, which maps to a chain
            chain_: "Chain" = self._chain.Chain(
                inputs=run_dict["inputs"],
                metadata=None,
                experiment_info=self._experiment_info.get(),
            )
            self._chains_map[run.id] = chain_
        else:
            span: "Span" = self._span.Span(
                inputs=run_dict["inputs"],
                category=_get_run_type(run),
                metadata=run_dict["extra"],
                name=run.name,
            )
            parent_chain = self._chains_map[run.parent_run_id]
            span.__api__start__(parent_chain)
            # Register the chain under this run's id as well, so that runs
            # nested below this one can find it through their parent id.
            self._chains_map[run.id] = parent_chain
            self._span_map[run.id] = span

    def _process_end_trace(self, run: "Run") -> None:
        """Set outputs on, and end, the span of a nested run.

        Root runs are left untouched here.

        Raises:
            KeyError: If ``run`` is nested and no span is registered for it.
        """
        if not run.parent_run_id:
            # Langchain will call _persist_run for us
            return

        run_dict: Dict[str, Any] = run.dict()
        # The run is finished: release its bookkeeping entries so the maps do
        # not grow without bound while the tracer stays alive.
        span = self._span_map.pop(run.id)
        self._chains_map.pop(run.id, None)
        span.set_outputs(outputs=run_dict["outputs"])
        span.__api__end__()

    def flush(self) -> None:
        """Call the ``flush`` function obtained from ``comet_llm``."""
        self._flush()

    def _on_llm_start(self, run: "Run") -> None:
        """Process the LLM Run upon start."""
        self._process_start_trace(run)

    def _on_llm_end(self, run: "Run") -> None:
        """Process the LLM Run."""
        self._process_end_trace(run)

    def _on_llm_error(self, run: "Run") -> None:
        """Process the LLM Run upon error."""
        self._process_end_trace(run)

    def _on_chain_start(self, run: "Run") -> None:
        """Process the Chain Run upon start."""
        self._process_start_trace(run)

    def _on_chain_end(self, run: "Run") -> None:
        """Process the Chain Run."""
        self._process_end_trace(run)

    def _on_chain_error(self, run: "Run") -> None:
        """Process the Chain Run upon error."""
        self._process_end_trace(run)

    def _on_tool_start(self, run: "Run") -> None:
        """Process the Tool Run upon start."""
        self._process_start_trace(run)

    def _on_tool_end(self, run: "Run") -> None:
        """Process the Tool Run."""
        self._process_end_trace(run)

    def _on_tool_error(self, run: "Run") -> None:
        """Process the Tool Run upon error."""
        self._process_end_trace(run)