from __future__ import annotations
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
hash_string,
import_pandas,
import_spacy,
import_textstat,
load_json,
)
if TYPE_CHECKING:
import pandas as pd
[docs]def import_clearml() -> Any:
"""导入clearml python包,并在未安装时引发错误。"""
return guard_import("clearml")
[docs]class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"""回调处理程序,用于记录到ClearML。
参数:
job_type (str): ClearML任务类型,如“推断”、“测试”或“质控”
project_name (str): ClearML项目名称
tags (list): 要添加到任务的标签
task_name (str): ClearML任务名称
visualize (bool): 是否可视化运行。
complexity_metrics (bool): 是否记录复杂度指标
stream_logs (bool): 是否将回调操作流式传输到ClearML
此处理程序将利用关联的回调方法和格式化每个回调函数的输入,其中包含有关LLM运行状态的元数据,并将响应添加到{method}_records和操作的记录列表中。然后将响应记录到ClearML控制台。"""
[docs] def __init__(
self,
task_type: Optional[str] = "inference",
project_name: Optional[str] = "langchain_callback_demo",
tags: Optional[Sequence] = None,
task_name: Optional[str] = None,
visualize: bool = False,
complexity_metrics: bool = False,
stream_logs: bool = False,
) -> None:
"""初始化回调处理程序。"""
clearml = import_clearml()
spacy = import_spacy()
super().__init__()
self.task_type = task_type
self.project_name = project_name
self.tags = tags
self.task_name = task_name
self.visualize = visualize
self.complexity_metrics = complexity_metrics
self.stream_logs = stream_logs
self.temp_dir = tempfile.TemporaryDirectory()
# Check if ClearML task already exists (e.g. in pipeline)
if clearml.Task.current_task():
self.task = clearml.Task.current_task()
else:
self.task = clearml.Task.init(
task_type=self.task_type,
project_name=self.project_name,
tags=self.tags,
task_name=self.task_name,
output_uri=True,
)
self.logger = self.task.get_logger()
warning = (
"The clearml callback is currently in beta and is subject to change "
"based on updates to `langchain`. Please report any issues to "
"https://github.com/allegroai/clearml/issues with the tag `langchain`."
)
self.logger.report_text(warning, level=30, print_console=True)
self.callback_columns: list = []
self.action_records: list = []
self.complexity_metrics = complexity_metrics
self.visualize = visualize
self.nlp = spacy.load("en_core_web_sm")
def _init_resp(self) -> Dict:
return {k: None for k in self.callback_columns}
[docs] def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""LLM启动时运行。"""
self.step += 1
self.llm_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_llm_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
for prompt in prompts:
prompt_resp = deepcopy(resp)
prompt_resp["prompts"] = prompt
self.on_llm_start_records.append(prompt_resp)
self.action_records.append(prompt_resp)
if self.stream_logs:
self.logger.report_text(prompt_resp)
[docs] def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""当LLM生成一个新的令牌时运行。"""
self.step += 1
self.llm_streams += 1
resp = self._init_resp()
resp.update({"action": "on_llm_new_token", "token": token})
resp.update(self.get_custom_callback_meta())
self.on_llm_token_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
[docs] def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""LLM 运行结束时运行。"""
self.step += 1
self.llm_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update({"action": "on_llm_end"})
resp.update(flatten_dict(response.llm_output or {}))
resp.update(self.get_custom_callback_meta())
for generations in response.generations:
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
generation_resp.update(self.analyze_text(generation.text))
self.on_llm_end_records.append(generation_resp)
self.action_records.append(generation_resp)
if self.stream_logs:
self.logger.report_text(generation_resp)
[docs] def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""当LLM出现错误时运行。"""
self.step += 1
self.errors += 1
[docs] def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""当链开始运行时运行。"""
self.step += 1
self.chain_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update({"action": "on_chain_start"})
resp.update(flatten_dict(serialized))
resp.update(self.get_custom_callback_meta())
chain_input = inputs.get("input", inputs.get("human_input"))
if isinstance(chain_input, str):
input_resp = deepcopy(resp)
input_resp["input"] = chain_input
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
elif isinstance(chain_input, list):
for inp in chain_input:
input_resp = deepcopy(resp)
input_resp.update(inp)
self.on_chain_start_records.append(input_resp)
self.action_records.append(input_resp)
if self.stream_logs:
self.logger.report_text(input_resp)
else:
raise ValueError("Unexpected data format provided!")
[docs] def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""当链结束运行时运行。"""
self.step += 1
self.chain_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_chain_end",
"outputs": outputs.get("output", outputs.get("text")),
}
)
resp.update(self.get_custom_callback_meta())
self.on_chain_end_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
[docs] def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""当链式错误时运行。"""
self.step += 1
self.errors += 1
[docs] def on_text(self, text: str, **kwargs: Any) -> None:
"""
当代理程序结束时运行。
"""
self.step += 1
self.text_ctr += 1
resp = self._init_resp()
resp.update({"action": "on_text", "text": text})
resp.update(self.get_custom_callback_meta())
self.on_text_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
[docs] def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""当代理程序运行结束时运行。"""
self.step += 1
self.agent_ends += 1
self.ends += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_finish",
"output": finish.return_values["output"],
"log": finish.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_finish_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
[docs] def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""在代理程序上运行的操作。"""
self.step += 1
self.tool_starts += 1
self.starts += 1
resp = self._init_resp()
resp.update(
{
"action": "on_agent_action",
"tool": action.tool,
"tool_input": action.tool_input,
"log": action.log,
}
)
resp.update(self.get_custom_callback_meta())
self.on_agent_action_records.append(resp)
self.action_records.append(resp)
if self.stream_logs:
self.logger.report_text(resp)
[docs] def analyze_text(self, text: str) -> dict:
"""使用textstat和spacy分析文本。
参数:
text (str): 要分析的文本。
返回:
(dict): 包含复杂度指标的字典。
"""
resp = {}
textstat = import_textstat()
spacy = import_spacy()
if self.complexity_metrics:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(
text
),
"dale_chall_readability_score": textstat.dale_chall_readability_score(
text
),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"text_standard": textstat.text_standard(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update(text_complexity_metrics)
if self.visualize and self.nlp and self.temp_dir.name is not None:
doc = self.nlp(text)
dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
dep_output_path = Path(
self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
)
dep_output_path.open("w", encoding="utf-8").write(dep_out)
ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
ent_output_path = Path(
self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
)
ent_output_path.open("w", encoding="utf-8").write(ent_out)
self.logger.report_media(
"Dependencies Plot", text, local_path=dep_output_path
)
self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
return resp
@staticmethod
def _build_llm_df(
base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
) -> pd.DataFrame:
base_df_fields = [field for field in base_df_fields if field in base_df]
rename_map = {
map_entry_k: map_entry_v
for map_entry_k, map_entry_v in rename_map.items()
if map_entry_k in base_df_fields
}
llm_df = base_df[base_df_fields].dropna(axis=1)
if rename_map:
llm_df = llm_df.rename(rename_map, axis=1)
return llm_df
def _create_session_analysis_df(self) -> Any:
"""使用会话中的所有信息创建一个数据框。"""
pd = import_pandas()
on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
base_df=on_llm_end_records_df,
base_df_fields=["step", "prompts"]
+ (["name"] if "name" in on_llm_end_records_df else ["id"]),
rename_map={"step": "prompt_step"},
)
complexity_metrics_columns = []
visualizations_columns: List = []
if self.complexity_metrics:
complexity_metrics_columns = [
"flesch_reading_ease",
"flesch_kincaid_grade",
"smog_index",
"coleman_liau_index",
"automated_readability_index",
"dale_chall_readability_score",
"difficult_words",
"linsear_write_formula",
"gunning_fog",
"text_standard",
"fernandez_huerta",
"szigriszt_pazos",
"gutierrez_polini",
"crawford",
"gulpease_index",
"osman",
]
llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
on_llm_end_records_df,
[
"step",
"text",
"token_usage_total_tokens",
"token_usage_prompt_tokens",
"token_usage_completion_tokens",
]
+ complexity_metrics_columns
+ visualizations_columns,
{"step": "output_step", "text": "output"},
)
session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
return session_analysis_df
[docs] def flush_tracker(
self,
name: Optional[str] = None,
langchain_asset: Any = None,
finish: bool = False,
) -> None:
"""刷新追踪器并设置会话。
此后的所有内容将成为一个新表。
参数:
name:迄今为止执行会话的名称,以便识别
langchain_asset:要保存的langchain资产。
finish:是否完成运行。
返回:
无。
"""
pd = import_pandas()
clearml = import_clearml()
# Log the action records
self.logger.report_table(
"Action Records", name, table_plot=pd.DataFrame(self.action_records)
)
# Session analysis
session_analysis_df = self._create_session_analysis_df()
self.logger.report_table(
"Session Analysis", name, table_plot=session_analysis_df
)
if self.stream_logs:
self.logger.report_text(
{
"action_records": pd.DataFrame(self.action_records),
"session_analysis": session_analysis_df,
}
)
if langchain_asset:
langchain_asset_path = Path(self.temp_dir.name, "model.json")
try:
langchain_asset.save(langchain_asset_path)
# Create output model and connect it to the task
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except ValueError:
langchain_asset.save_agent(langchain_asset_path)
output_model = clearml.OutputModel(
task=self.task, config_text=load_json(langchain_asset_path)
)
output_model.update_weights(
weights_filename=str(langchain_asset_path),
auto_delete_file=False,
target_filename=name,
)
except NotImplementedError as e:
print("Could not save model.") # noqa: T201
print(repr(e)) # noqa: T201
pass
# Cleanup after adding everything to ClearML
self.task.flush(wait_for_uploads=True)
self.temp_dir.cleanup()
self.temp_dir = tempfile.TemporaryDirectory()
self.reset_callback_meta()
if finish:
self.task.close()