# Source code for langchain_community.callbacks.clearml_callback

from __future__ import annotations

import tempfile
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import

from langchain_community.callbacks.utils import (
    BaseMetadataCallbackHandler,
    flatten_dict,
    hash_string,
    import_pandas,
    import_spacy,
    import_textstat,
    load_json,
)

if TYPE_CHECKING:
    import pandas as pd


def import_clearml() -> Any:
    """Import the ``clearml`` package, raising an error if it is not installed.

    Returns:
        The imported ``clearml`` module, as returned by ``guard_import``.
    """
    clearml_module = guard_import("clearml")
    return clearml_module
class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that logs LangChain runs to ClearML.

    Parameters:
        task_type (str): The ClearML task type, such as "inference", "testing"
            or "qc".
        project_name (str): The ClearML project name.
        tags (list): Tags to add to the task.
        task_name (str): Name of the ClearML task.
        visualize (bool): Whether to visualize the run (spaCy dependency and
            entity plots of each LLM generation).
        complexity_metrics (bool): Whether to log textstat complexity metrics
            for each LLM generation.
        stream_logs (bool): Whether to stream callback actions to ClearML.

    Each callback method formats its input together with metadata about the
    state of the LLM run and appends the resulting record both to the matching
    ``{method}_records`` list and to ``action_records``. When ``stream_logs``
    is set, the record is also reported to the ClearML console.
    ``flush_tracker`` uploads the accumulated records as ClearML tables.
    """

    # textstat metrics logged when ``complexity_metrics`` is enabled. Each name
    # is also the textstat function that computes the metric.
    _COMPLEXITY_METRIC_NAMES = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "text_standard",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )

    def __init__(
        self,
        task_type: Optional[str] = "inference",
        project_name: Optional[str] = "langchain_callback_demo",
        tags: Optional[Sequence] = None,
        task_name: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize the callback handler and attach it to a ClearML task."""
        clearml = import_clearml()
        spacy = import_spacy()
        super().__init__()

        self.task_type = task_type
        self.project_name = project_name
        self.tags = tags
        self.task_name = task_name
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        # Scratch space for visualization HTML and saved assets; it is
        # recreated on every flush_tracker() call.
        self.temp_dir = tempfile.TemporaryDirectory()

        # Reuse an already running ClearML task (e.g. inside a pipeline)
        # instead of creating a new one.
        current_task = clearml.Task.current_task()
        if current_task:
            self.task = current_task
        else:
            self.task = clearml.Task.init(
                task_type=self.task_type,
                project_name=self.project_name,
                tags=self.tags,
                task_name=self.task_name,
                output_uri=True,
            )
        self.logger = self.task.get_logger()

        warning = (
            "The clearml callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/allegroai/clearml/issues with the tag `langchain`."
        )
        # level=30 corresponds to logging.WARNING.
        self.logger.report_text(warning, level=30, print_console=True)

        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")

    def _init_resp(self) -> Dict:
        """Return a fresh record with every known callback column set to None."""
        return {k: None for k in self.callback_columns}

    def _append_record(
        self, records: List[Dict[str, Any]], resp: Dict[str, Any]
    ) -> None:
        """Store ``resp`` in ``records`` and ``action_records``; stream it if enabled."""
        records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when the LLM starts; logs one record per prompt."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for prompt in prompts:
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self._append_record(self.on_llm_start_records, prompt_resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when the LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_llm_token_records, resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when the LLM ends; logs one record per generation."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(self.analyze_text(generation.text))
                self._append_record(self.on_llm_end_records, generation_resp)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when the LLM errors; only the step and error counters change."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when a chain starts running.

        Raises:
            ValueError: If ``inputs["input"]`` (or, failing that,
                ``inputs["human_input"]``) is neither a str nor a list.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = inputs.get("input", inputs.get("human_input"))

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self._append_record(self.on_chain_start_records, input_resp)
        elif isinstance(chain_input, list):
            # NOTE(review): each element is merged with dict.update, so it is
            # presumably a mapping of extra fields -- confirm with callers.
            for inp in chain_input:
                input_resp = deepcopy(resp)
                input_resp.update(inp)
                self._append_record(self.on_chain_start_records, input_resp)
        else:
            # Also reached when neither "input" nor "human_input" is present
            # (chain_input is None).
            raise ValueError("Unexpected data format provided!")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when a chain ends; logs ``outputs["output"]`` or ``outputs["text"]``."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_chain_end",
                "outputs": outputs.get("output", outputs.get("text")),
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_chain_end_records, resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when a chain errors; only the step and error counters change."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when a tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_tool_start_records, resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when a tool ends; the output is logged as its ``str()``."""
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_tool_end_records, resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when a tool errors; only the step and error counters change."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text emitted during the run."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_text_records, resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when an agent finishes.

        The ``"output"`` entry of ``finish.return_values`` is logged, or None
        when the agent returned no such key.
        """
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values.get("output"),
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_agent_finish_records, resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on an agent action; it is counted as a tool start."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._append_record(self.on_agent_action_records, resp)

    def analyze_text(self, text: str) -> dict:
        """Analyze text using textstat and spacy.

        Complexity metrics are computed only when ``complexity_metrics`` is
        set. When ``visualize`` is set, spaCy dependency and entity plots are
        rendered to HTML in the temp dir and reported to ClearML as media.

        Parameters:
            text (str): The text to analyze.

        Returns:
            (dict): A dictionary of complexity metrics. It is empty when
                ``complexity_metrics`` is off.
        """
        resp: Dict[str, Any] = {}

        if self.complexity_metrics:
            # Imported lazily so textstat is only required when metrics are on.
            textstat = import_textstat()
            resp.update(
                {
                    metric: getattr(textstat, metric)(text)
                    for metric in self._COMPLEXITY_METRIC_NAMES
                }
            )

        if self.visualize and self.nlp and self.temp_dir.name is not None:
            spacy = import_spacy()
            doc = self.nlp(text)

            dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
            dep_output_path = Path(
                self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
            )
            # write_text closes the file, so the HTML is fully on disk before
            # report_media picks it up.
            dep_output_path.write_text(dep_out, encoding="utf-8")

            ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
            ent_output_path = Path(
                self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
            )
            ent_output_path.write_text(ent_out, encoding="utf-8")

            self.logger.report_media(
                "Dependencies Plot", text, local_path=dep_output_path
            )
            self.logger.report_media("Entities Plot", text, local_path=ent_output_path)

        return resp

    @staticmethod
    def _build_llm_df(
        base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
    ) -> pd.DataFrame:
        """Select, clean and rename columns of ``base_df``.

        Fields missing from ``base_df`` are ignored. Selected columns that
        contain any NaN are dropped. The rest are renamed using the entries
        of ``rename_map`` whose keys are among the selected fields.
        """
        base_df_fields = [field for field in base_df_fields if field in base_df]
        rename_map = {
            map_entry_k: map_entry_v
            for map_entry_k, map_entry_v in rename_map.items()
            if map_entry_k in base_df_fields
        }
        llm_df = base_df[base_df_fields].dropna(axis=1)
        if rename_map:
            llm_df = llm_df.rename(rename_map, axis=1)
        return llm_df

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session.

        The input side comes from ``on_llm_start_records``: prompt step,
        prompt, and the LLM "name" (or "id"). The output side comes from
        ``on_llm_end_records``: output step, text, token usage and optional
        complexity metrics. The two are concatenated column-wise, aligned on
        the row index.
        """
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        # Prompts and the serialized LLM fields are only recorded by
        # on_llm_start. on_llm_end records never contain "prompts", so the
        # input frame must be built from the start records.
        llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
            base_df=on_llm_start_records_df,
            base_df_fields=["step", "prompts"]
            + (["name"] if "name" in on_llm_start_records_df else ["id"]),
            rename_map={"step": "prompt_step"},
        )

        complexity_metrics_columns = (
            list(self._COMPLEXITY_METRIC_NAMES) if self.complexity_metrics else []
        )
        llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
            on_llm_end_records_df,
            [
                "step",
                "text",
                "token_usage_total_tokens",
                "token_usage_prompt_tokens",
                "token_usage_completion_tokens",
            ]
            + complexity_metrics_columns,
            {"step": "output_step", "text": "output"},
        )

        return pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)

    def _save_langchain_asset(
        self, clearml: Any, langchain_asset: Any, name: Optional[str]
    ) -> None:
        """Serialize ``langchain_asset`` to JSON and attach it as a ClearML output model.

        ``save`` is tried first and ``save_agent`` is the fallback when it
        raises ValueError. A NotImplementedError from either call is reported
        and the upload is skipped.
        """
        langchain_asset_path = Path(self.temp_dir.name, "model.json")
        try:
            try:
                langchain_asset.save(langchain_asset_path)
            except ValueError:
                langchain_asset.save_agent(langchain_asset_path)
        except NotImplementedError as e:
            # The nested try lets this also catch failures of the
            # save_agent fallback, not only of save.
            print("Could not save model.")  # noqa: T201
            print(repr(e))  # noqa: T201
            return

        # Create output model and connect it to the task
        output_model = clearml.OutputModel(
            task=self.task, config_text=load_json(langchain_asset_path)
        )
        output_model.update_weights(
            weights_filename=str(langchain_asset_path),
            auto_delete_file=False,
            target_filename=name,
        )

    def flush_tracker(
        self,
        name: Optional[str] = None,
        langchain_asset: Any = None,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and reset the session.

        Everything logged after this call ends up in a new table.

        Args:
            name: Name of the session performed so far, so it is identifiable.
                It is used as the table series and the model target filename.
            langchain_asset: The LangChain asset to save, if any.
            finish: Whether to close the ClearML task afterwards.

        Returns:
            None.
        """
        pd = import_pandas()
        clearml = import_clearml()

        # Log the action records
        action_records_df = pd.DataFrame(self.action_records)
        self.logger.report_table(
            "Action Records", name, table_plot=action_records_df
        )

        # Session analysis
        session_analysis_df = self._create_session_analysis_df()
        self.logger.report_table(
            "Session Analysis", name, table_plot=session_analysis_df
        )

        if self.stream_logs:
            self.logger.report_text(
                {
                    "action_records": action_records_df,
                    "session_analysis": session_analysis_df,
                }
            )

        if langchain_asset:
            self._save_langchain_asset(clearml, langchain_asset, name)

        # Cleanup after adding everything to ClearML
        self.task.flush(wait_for_uploads=True)
        self.temp_dir.cleanup()
        self.temp_dir = tempfile.TemporaryDirectory()
        self.reset_callback_meta()
        # action_records belongs to this class. It is presumably not covered by
        # the base class's reset_callback_meta(), so clear it explicitly so that
        # the next flush really starts a new table.
        self.action_records = []

        if finish:
            self.task.close()