import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import Generation, LLMResult
from langchain_core.utils import guard_import
import langchain_community
from langchain_community.callbacks.utils import (
BaseMetadataCallbackHandler,
flatten_dict,
import_pandas,
import_spacy,
import_textstat,
)
# Fallback name under which a saved LangChain asset is registered as a Comet
# model when the callback handler was not given an explicit ``name``.
LANGCHAIN_MODEL_NAME = "langchain-model"
def import_comet_ml() -> Any:
    """Import ``comet_ml`` and return the module.

    The import is delegated to ``guard_import`` so that a missing
    installation is reported by that helper rather than surfacing as a bare
    import failure here.

    Returns:
        The imported ``comet_ml`` module object.
    """
    return guard_import("comet_ml")
def _get_experiment(
    workspace: Optional[str] = None, project_name: Optional[str] = None
) -> Any:
    """Create a new Comet experiment.

    Args:
        workspace: Forwarded as the ``workspace`` argument of
            ``comet_ml.Experiment``.
        project_name: Forwarded as the ``project_name`` argument of
            ``comet_ml.Experiment``.

    Returns:
        The freshly constructed ``comet_ml.Experiment`` instance.
    """
    return import_comet_ml().Experiment(
        workspace=workspace,
        project_name=project_name,
    )
def _fetch_text_complexity_metrics(text: str) -> dict:
    """Compute a fixed set of ``textstat`` metrics for ``text``.

    Args:
        text: The text to analyse.

    Returns:
        A dict keyed by the ``textstat`` function name, holding the value
        that function returned for ``text``.
    """
    textstat = import_textstat()
    # Each entry is both the name of the textstat function to call and the key
    # under which its result is stored; the order is the evaluation order.
    metric_names = (
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        "text_standard",
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    )
    return {name: getattr(textstat, name)(text) for name in metric_names}
def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict:
    """Summarise per-generation metric records with ``DataFrame.describe``.

    Args:
        metrics: A sequence of metric records (one per generated output) that
            pandas can turn into a DataFrame.

    Returns:
        The ``describe()`` summary of that DataFrame converted with
        ``to_dict()``.
    """
    pandas = import_pandas()
    return pandas.DataFrame(metrics).describe().to_dict()
class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that logs LangChain activity to Comet.

    Each callback builds a record from the handler's current callback
    metadata plus the callback's own inputs, appends it to
    ``action_records`` (and, for LLM start/end, to ``on_llm_start_records`` /
    ``on_llm_end_records``), and optionally streams the text to the Comet
    experiment. ``flush_tracker`` uploads the accumulated session.

    Parameters:
        task_type (str): Stored on the handler as ``task_type``.
        workspace (str): Comet workspace passed to the experiment.
        project_name (str): Comet project name passed to the experiment.
        tags (list): Tags added to the experiment when non-empty.
        name (str): Experiment name; also used as the model name when a
            LangChain asset is logged.
        visualizations (list): spaCy displaCy styles to render for each
            generated output. When non-empty, the ``en_core_web_sm`` spaCy
            model is loaded at construction time.
        complexity_metrics (bool): Whether to compute textstat complexity
            metrics for every generated output.
        custom_metrics (callable): Optional function called as
            ``custom_metrics(generation, prompt_idx, gen_idx)`` that returns
            a dict of extra metrics for a generated output.
        stream_logs (bool): Whether each callback also logs its text to the
            Comet experiment immediately.
    """

    def __init__(
        self,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        stream_logs: bool = True,
    ) -> None:
        """Initialize the callback handler and create the Comet experiment."""
        self.comet_ml = import_comet_ml()
        super().__init__()

        self.task_type = task_type
        self.workspace = workspace
        self.project_name = project_name
        self.tags = tags
        self.visualizations = visualizations
        self.complexity_metrics = complexity_metrics
        self.custom_metrics = custom_metrics
        self.stream_logs = stream_logs
        # Scratch directory used by _log_model to serialise the asset.
        self.temp_dir = tempfile.TemporaryDirectory()

        self.experiment = _get_experiment(workspace, project_name)
        self.experiment.log_other("Created from", "langchain")
        if tags:
            self.experiment.add_tags(tags)
        self.name = name
        if self.name:
            self.experiment.set_name(self.name)

        warning = (
            "The comet_ml callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/comet-ml/issue-tracking/issues with the tag "
            "`langchain`."
        )
        self.comet_ml.LOGGER.warning(warning)

        self.callback_columns: list = []
        self.action_records: list = []
        # The spaCy pipeline is only needed (and only loaded) when at least
        # one visualization style was requested.
        if self.visualizations:
            spacy = import_spacy()
            self.nlp = spacy.load("en_core_web_sm")
        else:
            self.nlp = None

    def _init_resp(self) -> Dict:
        """Return a new record with every callback column set to ``None``."""
        return {k: None for k in self.callback_columns}

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts; records one entry per prompt."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_start"})
        metadata.update(flatten_dict(serialized))
        metadata.update(self.get_custom_callback_meta())

        for prompt in prompts:
            # Each prompt gets its own copy so records do not share state.
            prompt_resp = deepcopy(metadata)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)

            if self.stream_logs:
                self._log_stream(prompt, metadata, self.step)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.action_records.append(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running; records and scores each generation."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        metadata = self._init_resp()
        metadata.update({"action": "on_llm_end"})
        metadata.update(flatten_dict(response.llm_output or {}))
        metadata.update(self.get_custom_callback_meta())

        output_complexity_metrics = []
        output_custom_metrics = []

        for prompt_idx, generations in enumerate(response.generations):
            for gen_idx, generation in enumerate(generations):
                text = generation.text

                generation_resp = deepcopy(metadata)
                generation_resp.update(flatten_dict(generation.dict()))

                complexity_metrics = self._get_complexity_metrics(text)
                if complexity_metrics:
                    output_complexity_metrics.append(complexity_metrics)
                    generation_resp.update(complexity_metrics)

                custom_metrics = self._get_custom_metrics(
                    generation, prompt_idx, gen_idx
                )
                if custom_metrics:
                    output_custom_metrics.append(custom_metrics)
                    generation_resp.update(custom_metrics)

                if self.stream_logs:
                    self._log_stream(text, metadata, self.step)

                self.action_records.append(generation_resp)
                self.on_llm_end_records.append(generation_resp)

        # Summaries are logged once per callback, across all generations.
        self._log_text_metrics(output_complexity_metrics, step=self.step)
        self._log_text_metrics(output_custom_metrics, step=self.step)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors; only updates the step and error counters."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running; records every string input."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for chain_input_key, chain_input_val in inputs.items():
            if isinstance(chain_input_val, str):
                input_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_input_val, resp, self.step)
                input_resp.update({chain_input_key: chain_input_val})
                self.action_records.append(input_resp)
            else:
                # Only plain strings are recorded; anything else is skipped.
                self.comet_ml.LOGGER.warning(
                    "Unexpected data format provided! "
                    f"Input Value for {chain_input_key} will not be logged"
                )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running; records every string output."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end"})
        resp.update(self.get_custom_callback_meta())

        for chain_output_key, chain_output_val in outputs.items():
            if isinstance(chain_output_val, str):
                output_resp = deepcopy(resp)
                if self.stream_logs:
                    self._log_stream(chain_output_val, resp, self.step)
                output_resp.update({chain_output_key: chain_output_val})
                self.action_records.append(output_resp)
            else:
                # Only plain strings are recorded; anything else is skipped.
                self.comet_ml.LOGGER.warning(
                    "Unexpected data format provided! "
                    f"Output Value for {chain_output_key} will not be logged"
                )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors; only updates the step and error counters."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Record a piece of text reported through the ``on_text`` callback."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text"})
        resp.update(self.get_custom_callback_meta())

        if self.stream_logs:
            self._log_stream(text, resp, self.step)

        resp.update({"text": text})
        self.action_records.append(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running; records its ``output`` and log."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        output = finish.return_values["output"]
        log = finish.log

        resp.update({"action": "on_agent_finish", "log": log})
        resp.update(self.get_custom_callback_meta())

        if self.stream_logs:
            self._log_stream(output, resp, self.step)

        resp.update({"output": output})
        self.action_records.append(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action; records the tool, its input and the log."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        tool = action.tool
        # The tool input is stringified before being recorded.
        tool_input = str(action.tool_input)
        log = action.log

        resp = self._init_resp()
        resp.update({"action": "on_agent_action", "log": log, "tool": tool})
        resp.update(self.get_custom_callback_meta())

        if self.stream_logs:
            self._log_stream(tool_input, resp, self.step)

        resp.update({"tool_input": tool_input})
        self.action_records.append(resp)

    def _get_complexity_metrics(self, text: str) -> dict:
        """Compute text complexity metrics using textstat.

        Parameters:
            text (str): The text to analyze.

        Returns:
            (dict): The complexity metrics, or an empty dict when
                ``complexity_metrics`` is disabled on the handler.
        """
        resp = {}
        if self.complexity_metrics:
            text_complexity_metrics = _fetch_text_complexity_metrics(text)
            resp.update(text_complexity_metrics)

        return resp

    def _get_custom_metrics(
        self, generation: Generation, prompt_idx: int, gen_idx: int
    ) -> dict:
        """Compute custom metrics for an LLM generated output.

        Parameters:
            generation (Generation): A single generated output.
            prompt_idx (int): Index of the prompt within the LLM result.
            gen_idx (int): Index of the generation for that prompt.

        Returns:
            dict: The custom metrics, or an empty dict when no
                ``custom_metrics`` callable is configured.
        """
        resp = {}
        if self.custom_metrics:
            custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx)
            resp.update(custom_metrics)

        return resp

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        task_type: Optional[str] = "inference",
        workspace: Optional[str] = None,
        project_name: Optional[str] = "comet-langchain-demo",
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
        finish: bool = False,
        reset: bool = False,
    ) -> None:
        """Flush the tracker and optionally set up a new session.

        The session recorded so far is logged to the current experiment.
        When ``reset`` is true the handler is re-initialised, so everything
        recorded afterwards belongs to a new experiment.

        Args:
            langchain_asset: The LangChain asset to save. It is also used to
                look up the LLM parameters for the session table.
            task_type, workspace, project_name, tags, name, visualizations,
            complexity_metrics, custom_metrics: Settings for the new session;
                only used when ``reset`` is true. A falsy value keeps the
                handler's current setting. Note that ``project_name`` has a
                non-empty default, so a reset switches to that project unless
                the caller passes one explicitly.
            finish: Whether to end the current experiment.
            reset: Whether to re-initialise the handler after flushing.

        Returns:
            None
        """
        self._log_session(langchain_asset)

        if langchain_asset:
            try:
                self._log_model(langchain_asset)
            except Exception:
                # Exporting the asset is best-effort; never fail the flush.
                self.comet_ml.LOGGER.error(
                    "Failed to export agent or LLM to Comet",
                    exc_info=True,
                    extra={"show_traceback": True},
                )

        if finish:
            self.experiment.end()

        if reset:
            self._reset(
                task_type,
                workspace,
                project_name,
                tags,
                name,
                visualizations,
                complexity_metrics,
                custom_metrics,
            )

    def _log_stream(self, prompt: str, metadata: dict, step: int) -> None:
        """Log one piece of text with its metadata to the experiment."""
        self.experiment.log_text(prompt, metadata=metadata, step=step)

    def _log_model(self, langchain_asset: Any) -> None:
        """Log the asset's LLM parameters and its serialised form to Comet.

        The asset is written to ``model.json`` inside the handler's temporary
        directory via ``save``; if that raises one of the handled errors,
        ``save_agent`` is tried instead when the asset provides it.
        """
        model_parameters = self._get_llm_parameters(langchain_asset)
        self.experiment.log_parameters(model_parameters, prefix="model")

        langchain_asset_path = Path(self.temp_dir.name, "model.json")
        model_name = self.name if self.name else LANGCHAIN_MODEL_NAME

        try:
            if hasattr(langchain_asset, "save"):
                langchain_asset.save(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
        except (ValueError, AttributeError, NotImplementedError) as e:
            if hasattr(langchain_asset, "save_agent"):
                langchain_asset.save_agent(langchain_asset_path)
                self.experiment.log_model(model_name, str(langchain_asset_path))
            else:
                self.comet_ml.LOGGER.error(
                    f"{e}"
                    " Could not save Langchain Asset "
                    f"for {langchain_asset.__class__.__name__}"
                )

    def _log_session(self, langchain_asset: Optional[Any] = None) -> None:
        """Upload the session table, raw action records and visualizations.

        Each of the three uploads is best-effort: a failure is logged as a
        warning and the remaining uploads are still attempted.
        """
        # Stays None when the session table cannot be built, so the
        # visualization step below can be skipped instead of failing on an
        # undefined name.
        llm_session_df = None
        try:
            llm_session_df = self._create_session_analysis_dataframe(langchain_asset)
            # Log the cleaned dataframe as a table
            self.experiment.log_table("langchain-llm-session.csv", llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        try:
            metadata = {"langchain_version": str(langchain_community.__version__)}
            # Log the langchain low-level records as a JSON file directly
            self.experiment.log_asset_data(
                self.action_records, "langchain-action_records.json", metadata=metadata
            )
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log session data to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

        if llm_session_df is None:
            # The failure was already reported above; there is nothing to
            # visualize.
            return

        try:
            self._log_visualizations(llm_session_df)
        except Exception:
            self.comet_ml.LOGGER.warning(
                "Failed to log visualizations to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

    def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None:
        """Log summary statistics of per-generation metrics, if any."""
        if not metrics:
            return

        metrics_summary = _summarize_metrics_for_generated_outputs(metrics)
        # One log_metrics call per summarised column, prefixed by its name.
        for key, value in metrics_summary.items():
            self.experiment.log_metrics(value, prefix=key, step=step)

    def _log_visualizations(self, session_df: Any) -> None:
        """Render the requested displaCy styles for every generated output.

        Does nothing unless visualizations were requested and the spaCy
        pipeline was loaded. A failure for one style is logged as a warning
        and does not stop the remaining renders.
        """
        if not (self.visualizations and self.nlp):
            return

        spacy = import_spacy()

        prompts = session_df["prompts"].tolist()
        outputs = session_df["text"].tolist()

        for idx, (prompt, output) in enumerate(zip(prompts, outputs)):
            doc = self.nlp(output)
            sentence_spans = list(doc.sents)

            for visualization in self.visualizations:
                try:
                    html = spacy.displacy.render(
                        sentence_spans,
                        style=visualization,
                        options={"compact": True},
                        jupyter=False,
                        page=True,
                    )
                    self.experiment.log_asset_data(
                        html,
                        name=f"langchain-viz-{visualization}-{idx}.html",
                        metadata={"prompt": prompt},
                        step=idx,
                    )
                except Exception as e:
                    self.comet_ml.LOGGER.warning(
                        e, exc_info=True, extra={"show_traceback": True}
                    )

    def _reset(
        self,
        task_type: Optional[str] = None,
        workspace: Optional[str] = None,
        project_name: Optional[str] = None,
        tags: Optional[Sequence] = None,
        name: Optional[str] = None,
        visualizations: Optional[List[str]] = None,
        complexity_metrics: bool = False,
        custom_metrics: Optional[Callable] = None,
    ) -> None:
        """Re-initialise the handler for a new session.

        Every falsy argument falls back to the handler's current value, so a
        setting can be replaced but not cleared (``complexity_metrics=False``
        keeps the current setting).
        """
        _task_type = task_type if task_type else self.task_type
        _workspace = workspace if workspace else self.workspace
        _project_name = project_name if project_name else self.project_name
        _tags = tags if tags else self.tags
        _name = name if name else self.name
        _visualizations = visualizations if visualizations else self.visualizations
        _complexity_metrics = (
            complexity_metrics if complexity_metrics else self.complexity_metrics
        )
        _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics

        # __init__ creates the new experiment and a fresh temporary directory.
        # stream_logs is forwarded explicitly so the caller's choice is kept
        # rather than reverting to the constructor default.
        self.__init__(  # type: ignore
            task_type=_task_type,
            workspace=_workspace,
            project_name=_project_name,
            tags=_tags,
            name=_name,
            visualizations=_visualizations,
            complexity_metrics=_complexity_metrics,
            custom_metrics=_custom_metrics,
            stream_logs=self.stream_logs,
        )

        self.reset_callback_meta()

    def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
        """Join the LLM start and end records into one pandas DataFrame.

        Start records are repeated ``n`` times each (``n`` taken from the
        asset's LLM parameters, default 1) and then merged with the end
        records by row position.
        """
        pd = import_pandas()

        llm_parameters = self._get_llm_parameters(langchain_asset)
        num_generations_per_prompt = llm_parameters.get("n", 1)

        llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        # Repeat each input row based on the number of outputs generated per prompt
        llm_start_records_df = llm_start_records_df.loc[
            llm_start_records_df.index.repeat(num_generations_per_prompt)
        ].reset_index(drop=True)
        llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_session_df = pd.merge(
            llm_start_records_df,
            llm_end_records_df,
            left_index=True,
            right_index=True,
            suffixes=["_llm_start", "_llm_end"],
        )

        return llm_session_df

    def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
        """Return the LLM parameters of ``langchain_asset`` as a dict.

        The LLM is looked up through ``agent.llm_chain.llm``,
        ``llm_chain.llm`` or ``llm`` (first attribute present wins);
        otherwise the asset's own ``dict()`` is used. An empty dict is
        returned when no asset is given or the lookup raises.
        """
        if not langchain_asset:
            return {}
        try:
            if hasattr(langchain_asset, "agent"):
                llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm_chain"):
                llm_parameters = langchain_asset.llm_chain.llm.dict()
            elif hasattr(langchain_asset, "llm"):
                llm_parameters = langchain_asset.llm.dict()
            else:
                llm_parameters = langchain_asset.dict()
        except Exception:
            # Parameters are auxiliary information; treat failures as "none".
            return {}

        return llm_parameters