Source code for langchain_community.callbacks.comet_ml_callback

import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import guard_import

import langchain_community
from langchain_community.callbacks.utils import (
    BaseMetadataCallbackHandler,
    flatten_dict,
    import_pandas,
    import_spacy,
    import_textstat,
)

# Fallback name under which a saved LangChain asset is logged to Comet when
# the callback handler was not given an explicit experiment name.
LANGCHAIN_MODEL_NAME = "langchain-model"


def import_comet_ml() -> Any:
    """Import comet_ml and raise an error if it is not installed."""
    comet_ml_module = guard_import("comet_ml")
    return comet_ml_module
def _get_experiment(
    workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
    """Create a new Comet experiment in the given workspace and project."""
    return import_comet_ml().Experiment(
        workspace=workspace,
        project_name=project_name,
    )


def _fetch_text_complexity_metrics(text: str) -> dict:
    """Compute every supported textstat readability score for ``text``.

    The returned dict maps each textstat function name to the value that
    function returns for ``text``.
    """
    textstat = import_textstat()
    # Each entry is both the result key and the name of the textstat function
    # that produces it; the scores are computed in this order.
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "text_standard",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    scores = {}
    for metric_name in metric_names:
        scores[metric_name] = getattr(textstat, metric_name)(text)
    return scores


def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
    """Return ``DataFrame.describe()`` statistics of ``metrics`` as a dict."""
    pandas = import_pandas()
    return pandas.DataFrame(metrics).describe().to_dict()
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that logs LangChain activity to Comet.

    Parameters:
        task_type (str): The type of Comet task, such as "inference",
            "testing" or "qc". Stored on the handler.
        workspace (str): Comet workspace passed to the experiment.
        project_name (str): Comet project name passed to the experiment.
        tags (list): Tags to add to the experiment.
        name (str): Name given to the experiment; also used as the name of
            the logged model asset.
        visualizations (list): spaCy displacy styles to render for every
            generated output. When non-empty, the ``en_core_web_sm`` spaCy
            model is loaded at construction time.
        complexity_metrics (bool): Whether to log text complexity metrics.
        custom_metrics (Callable): Optional function called as
            ``custom_metrics(generation, prompt_idx, gen_idx)``; the mapping
            it returns is merged into the record of that generation.
        stream_logs (bool): Whether to stream callback actions to Comet.

    Each callback method formats its input together with metadata about the
    state of the LLM run, appends the result to ``action_records`` (and, for
    the LLM start/end callbacks, to the ``{method}_records`` list), and can
    stream the text to Comet as it happens.
    """

    def __init__(
        self,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        stream_logs: bool = True,
    ) -> None:
        """Initialize the callback handler and start a Comet experiment."""
        self.comet_ml = import_comet_ml()
        super().__init__()

        self.task_type = task_type
        self.workspace = workspace
        self.project_name = project_name
        self.tags = tags
        self.visualizations = visualizations
        self.complexity_metrics = complexity_metrics
        self.custom_metrics = custom_metrics
        self.stream_logs = stream_logs
        # Scratch directory used by _log_model to serialize the asset.
        self.temp_dir = tempfile.TemporaryDirectory()

        self.experiment = _get_experiment(workspace, project_name)
        self.experiment.log_other("Created from", "langchain")
        if tags:
            self.experiment.add_tags(tags)
        self.name = name
        if self.name:
            self.experiment.set_name(self.name)

        warning = (
            "The comet_ml callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/comet-ml/issue-tracking/issues with the tag "
            "`langchain`."
        )
        self.comet_ml.LOGGER.warning(warning)

        self.callback_columns: list = []
        self.action_records: list = []
        if self.visualizations:
            # The spaCy pipeline is only needed to split outputs into
            # sentences for displacy, so skip loading it otherwise.
            spacy = import_spacy()
            self.nlp = spacy.load("en_core_web_sm")
        else:
            self.nlp = None

    def _init_resp(self) -> Dict:
        """Return a fresh record with every callback column set to None."""
        return {k: None for k in self.callback_columns}

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_start"})
        metadata.update(flatten_dict(serialized))
        metadata.update(self.get_custom_callback_meta())

        for prompt in prompts:
            # One record per prompt so records can later be paired with
            # the generations produced for that prompt.
            prompt_resp = deepcopy(metadata)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)

            if self.stream_logs:
                self._log_stream(prompt, metadata, self.step)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.action_records.append(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_end"})
        metadata.update(flatten_dict(response.llm_output or {}))
        metadata.update(self.get_custom_callback_meta())

        output_complexity_metrics = []
        output_custom_metrics = []

        for prompt_idx, generations in enumerate(response.generations):
            for gen_idx, generation in enumerate(generations):
                text = generation.text

                generation_resp = deepcopy(metadata)
                generation_resp.update(flatten_dict(generation.dict()))

                complexity_metrics = self._get_complexity_metrics(text)
                if complexity_metrics:
                    output_complexity_metrics.append(complexity_metrics)
                    generation_resp.update(complexity_metrics)

                custom_metrics = self._get_custom_metrics(
                    generation, prompt_idx, gen_idx
                )
                if custom_metrics:
                    output_custom_metrics.append(custom_metrics)
                    generation_resp.update(custom_metrics)

                if self.stream_logs:
                    self._log_stream(text, metadata, self.step)

                self.action_records.append(generation_resp)
                self.on_llm_end_records.append(generation_resp)

        # Summaries are logged once per callback, over all generations.
        self._log_text_metrics(output_complexity_metrics, step=self.step)
        self._log_text_metrics(output_custom_metrics, step=self.step)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for chain_input_key, chain_input_val in inputs.items():
            if isinstance(chain_input_val, str):
                input_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_input_val, resp, self.step)
                input_resp.update({chain_input_key: chain_input_val})
                self.action_records.append(input_resp)
            else:
                # Only string inputs are recorded; anything else is skipped.
                self.comet_ml.LOGGER.warning(
                    f"Unexpected data format provided! "
                    f"Input Value for {chain_input_key} will not be logged"
                )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end"})
        resp.update(self.get_custom_callback_meta())

        for chain_output_key, chain_output_val in outputs.items():
            if isinstance(chain_output_val, str):
                output_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_output_val, resp, self.step)
                output_resp.update({chain_output_key: chain_output_val})
                self.action_records.append(output_resp)
            else:
                # Only string outputs are recorded; anything else is skipped.
                self.comet_ml.LOGGER.warning(
                    f"Unexpected data format provided! "
                    f"Output Value for {chain_output_key} will not be logged"
                )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(input_str, resp, self.step)

        resp.update({"input_str": input_str})
        self.action_records.append(resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end"})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(output, resp, self.step)

        resp.update({"output": output})
        self.action_records.append(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Record a text event and optionally stream it to Comet."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text"})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(text, resp, self.step)

        resp.update({"text": text})
        self.action_records.append(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        output = finish.return_values["output"]
        log = finish.log

        resp.update({"action": "on_agent_finish", "log": log})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(output, resp, self.step)

        resp.update({"output": output})
        self.action_records.append(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        tool = action.tool
        tool_input = str(action.tool_input)
        log = action.log

        resp = self._init_resp()
        resp.update({"action": "on_agent_action", "log": log, "tool": tool})
        resp.update(self.get_custom_callback_meta())
        if self.stream_logs:
            self._log_stream(tool_input, resp, self.step)

        resp.update({"tool_input": tool_input})
        self.action_records.append(resp)

    def _get_complexity_metrics(self, text: str) -> dict:
        """Compute text complexity metrics using textstat.

        Parameters:
            text (str): The text to analyze.

        Returns:
            (dict): The complexity metrics, or an empty dict when
                ``complexity_metrics`` is disabled on the handler.
        """
        resp = {}
        if self.complexity_metrics:
            text_complexity_metrics = _fetch_text_complexity_metrics(text)
            resp.update(text_complexity_metrics)

        return resp

    def _get_custom_metrics(
        self, generation: Generation, prompt_idx: int, gen_idx: int
    ) -> dict:
        """Compute custom metrics for an LLM generated output.

        Parameters:
            generation (Generation): Output generation from an LLM.
            prompt_idx (int): List index of the input prompt.
            gen_idx (int): List index of the generated output.

        Returns:
            dict: The custom metrics, or an empty dict when no
                ``custom_metrics`` callable is configured.
        """
        resp = {}
        if self.custom_metrics:
            custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
            resp.update(custom_metrics)

        return resp

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = "comet-langchain-demo",
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        finish: bool = False,
        reset: bool = False,
    ) -> None:
        """Flush the tracker and set up the session.

        Everything after this will be a new table.

        Parameters:
            langchain_asset: The LangChain asset to save. When given, its
                LLM parameters and serialized form are logged as a model.
            task_type, workspace, project_name, tags, name, visualizations,
            complexity_metrics, custom_metrics: Settings for the new
                experiment; only used when ``reset`` is true. Falsy values
                fall back to the handler's current setting. Note that
                ``project_name`` defaults to "comet-langchain-demo" here,
                so it replaces the current project unless set explicitly.
            finish: Whether to end the current experiment.
            reset: Whether to re-initialize the handler afterwards, which
                starts a new experiment.

        Returns:
            None
        """
        self._log_session(langchain_asset)

        if langchain_asset:
            try:
                self._log_model(langchain_asset)
            except Exception:
                # Exporting the model is best-effort; never fail the flush.
                self.comet_ml.LOGGER.error(
                    "Failed to export agent or LLM to Comet",
                    exc_info=True,
                    extra={"show_traceback": True},
                )

        if finish:
            self.experiment.end()

        if reset:
            self._reset(
                task_type,
                workspace,
                project_name,
                tags,
                name,
                visualizations,
                complexity_metrics,
                custom_metrics,
            )

    def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
        """Log a single piece of text with its metadata to the experiment."""
        self.experiment.log_text(prompt, metadata=metadata, step=step)

    def _log_model(self, langchain_asset: Any) -> None:
        """Log the asset's LLM parameters and its serialized form to Comet."""
        model_parameters = self._get_llm_parameters(langchain_asset)
        self.experiment.log_parameters(model_parameters, prefix="model")

        langchain_asset_path = Path(self.temp_dir.name, "model.json")
        model_name = self.name if self.name else LANGCHAIN_MODEL_NAME

        try:
            if hasattr(langchain_asset, "save"):
                langchain_asset.save(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
        except (ValueError, AttributeError, NotImplementedError) as e:
            # ``save`` failed; fall back to ``save_agent`` when available.
            if hasattr(langchain_asset, "save_agent"):
                langchain_asset.save_agent(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
            else:
                self.comet_ml.LOGGER.error(
                    f"{e}"
                    " Could not save Langchain Asset "
                    f"for {langchain_asset.__class__.__name__}"
                )

    def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
        """Log the session table, raw action records and visualizations.

        The three steps are independent and best-effort: a failure in one
        is logged as a warning and does not stop the others.
        """
        # Stays None when the session frame cannot be built, so the
        # visualization step can be skipped instead of raising
        # UnboundLocalError and emitting a misleading warning.
        llm_session_df = None
        try:
            llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
            # Log the cleaned dataframe as a table
            self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        try:
            metadata = {"langchain_version": str(langchain_community.__version__)}
            # Log the langchain low-level records as a JSON file directly
            self.experiment.log_asset_data(
                self.action_records, "langchain-action_records.json", metadata=metadata
            )
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        if llm_session_df is None:
            # No session frame, so there is nothing to visualize.
            return

        try:
            self._log_visualizations(llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log visualizations to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

    def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
        """Log summary statistics of per-generation metrics for ``step``.

        Does nothing when ``metrics`` is empty.
        """
        if not metrics:
            return

        metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
        for key, value in metrics_summary.items():
            self.experiment.log_metrics(value, prefix=key, step=step)

    def _log_visualizations(self, session_df: Any) -> None:
        """Render and log a displacy page per output and configured style.

        Does nothing unless visualizations are configured and the spaCy
        pipeline was loaded.
        """
        if not (self.visualizations and self.nlp):
            return

        spacy = import_spacy()

        prompts = session_df["prompts"].tolist()
        outputs = session_df["text"].tolist()

        for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
            doc = self.nlp(output)
            sentence_spans = list(doc.sents)

            for visualization in self.visualizations:
                try:
                    html = spacy.displacy.render(
                        sentence_spans,
                        style=visualization,
                        options={"compact": True},
                        jupyter=False,
                        page=True,
                    )
                    self.experiment.log_asset_data(
                        html,
                        name=f"langchain-viz-{visualization}-{idx}.html",
                        metadata={"prompt": prompt},
                        step=idx,
                    )
                except Exception as e:
                    # One failed rendering must not stop the remaining ones.
                    self.comet_ml.LOGGER.warning(
                        e, exc_info=True, extra={"show_traceback": True}
                    )

        return

    def _reset(
        self,
        task_type: Optional[str] = None,
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
    ) -> None:
        """Re-initialize the handler, starting a new Comet experiment.

        Any falsy argument falls back to the handler's current setting.
        """
        _task_type = task_type if task_type else self.task_type
        _workspace = workspace if workspace else self.workspace
        _project_name = project_name if project_name else self.project_name
        _tags = tags if tags else self.tags
        _name = name if name else self.name
        _visualizations = visualizations if visualizations else self.visualizations
        _complexity_metrics = (
            complexity_metrics if complexity_metrics else self.complexity_metrics
        )
        _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics

        self.__init__(  # type: ignore
            task_type=_task_type,
            workspace=_workspace,
            project_name=_project_name,
            tags=_tags,
            name=_name,
            visualizations=_visualizations,
            complexity_metrics=_complexity_metrics,
            custom_metrics=_custom_metrics,
            # Keep the caller's streaming choice; without this a reset
            # silently reverted to the constructor default (True).
            stream_logs=self.stream_logs,
        )

        # __init__ has already created a fresh temp_dir, so no second
        # TemporaryDirectory is allocated here.
        self.reset_callback_meta()

    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
        """Build a DataFrame pairing each LLM start record with its output.

        Start records are repeated once per generation requested per prompt
        (the LLM's ``n`` parameter, defaulting to 1) and then merged with
        the end records on row position.

        Returns:
            A pandas DataFrame; overlapping columns get the suffixes
            ``_llm_start`` and ``_llm_end``.
        """
        pd = import_pandas()

        llm_parameters = self._get_llm_parameters(langchain_asset)
        num_generations_per_prompt = llm_parameters.get("n", 1)
        if num_generations_per_prompt is None:
            # An explicit ``n=None`` in the LLM parameters is treated like a
            # missing value rather than passed on to ``Index.repeat``.
            num_generations_per_prompt = 1

        llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        # Repeat each input row based on the number of outputs generated per prompt
        llm_start_records_df = llm_start_records_df.loc[
            llm_start_records_df.index.repeat(num_generations_per_prompt)
        ].reset_index(drop=True)
        llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_session_df = pd.merge(
            llm_start_records_df,
            llm_end_records_df,
            left_index=True,
            right_index=True,
            suffixes=["_llm_start", "_llm_end"],
        )

        return llm_session_df

    def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
        """Return the parameter dict of the LLM behind ``langchain_asset``.

        The LLM is looked up on ``asset.agent.llm_chain.llm``, then
        ``asset.llm_chain.llm``, then ``asset.llm``, and finally the asset
        itself. Returns an empty dict when no asset is given or the lookup
        raises.
        """
        if not langchain_asset:
            return {}
        try:
            if hasattr(langchain_asset, "agent"):
                llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm_chain"):
                llm_parameters = langchain_asset.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm"):
                llm_parameters = langchain_asset.llm.dict()
            else:
                llm_parameters = langchain_asset.dict()
        except Exception:
            return {}

        return llm_parameters