# Source code for langchain_community.callbacks.mlflow_callback

import logging
import os
import random
import string
import tempfile
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult
from langchain_core.utils import get_from_dict_or_env, guard_import

from langchain_community.callbacks.utils import (
    BaseMetadataCallbackHandler,
    flatten_dict,
    hash_string,
    import_pandas,
    import_spacy,
    import_textstat,
)

logger = logging.getLogger(__name__)


def import_mlflow() -> Any:
    """Import the mlflow python package and raise an error if it is not installed."""
    return guard_import("mlflow")


def mlflow_callback_metrics() -> List[str]:
    """Get the metrics to log to MLFlow."""
    return [
        "step",
        "starts",
        "ends",
        "errors",
        "text_ctr",
        "chain_starts",
        "chain_ends",
        "llm_starts",
        "llm_ends",
        "llm_streams",
        "tool_starts",
        "tool_ends",
        "agent_ends",
        "retriever_starts",
        "retriever_ends",
    ]


def get_text_complexity_metrics() -> List[str]:
    """Get the text complexity metrics from textstat."""
    return [
        "flesch_reading_ease",
        "flesch_kincaid_grade",
        "smog_index",
        "coleman_liau_index",
        "automated_readability_index",
        "dale_chall_readability_score",
        "difficult_words",
        "linsear_write_formula",
        "gunning_fog",
        # "text_standard"
        "fernandez_huerta",
        "szigriszt_pazos",
        "gutierrez_polini",
        "crawford",
        "gulpease_index",
        "osman",
    ]


def analyze_text(
    text: str,
    nlp: Any = None,
    textstat: Any = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        nlp (spacy.lang): The spacy language model to use for visualization.
        textstat: The textstat library to use for complexity metrics.

    Returns:
        (dict): A dictionary containing the complexity metrics and the
            visualization files serialized to HTML strings.
    """
    resp: Dict[str, Any] = {}
    if textstat is not None:
        text_complexity_metrics = {
            key: getattr(textstat, key)(text) for key in get_text_complexity_metrics()
        }
        resp.update({"text_complexity_metrics": text_complexity_metrics})
        resp.update(text_complexity_metrics)

    if nlp is not None:
        spacy = import_spacy()
        doc = nlp(text)
        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
        text_visualizations = {
            "dependency_tree": dep_out,
            "entities": ent_out,
        }
        resp.update(text_visualizations)

    return resp


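# A minimal usage sketch for analyze_text (illustrative, not part of the
# module): it assumes the optional textstat and spacy packages plus the
# en_core_web_sm model are installed; pass only the arguments you have.
#
#     import spacy
#     import textstat
#
#     nlp = spacy.load("en_core_web_sm")
#     result = analyze_text(
#         "The quick brown fox jumps over the lazy dog.", nlp=nlp, textstat=textstat
#     )
#     result["flesch_reading_ease"]  # numeric readability score
#     result["entities"]             # named-entity visualization as an HTML string

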
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Construct an html element from a prompt and a generation.

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (str): The html string.
    """
    formatted_prompt = prompt.replace("\n", "<br>")
    formatted_generation = generation.replace("\n", "<br>")

    return f"""
    <p style="color:black;">{formatted_prompt}:</p>
    <blockquote>
      <p style="color:green;">
        {formatted_generation}
      </p>
    </blockquote>
    """


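# Illustrative call (hypothetical values): newlines in both arguments are
# rewritten to <br> so the pair renders as a single prompt/response block.
#
#     html = construct_html_from_prompt_and_generation(
#         "What is MLflow?", "MLflow is an open source ML lifecycle platform."
#     )
#     # The string can then be logged, e.g. via MlflowLogger.html(html, "example").

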
class MlflowLogger:
    """Callback Handler that logs metrics and artifacts to mlflow server.

    Parameters:
        name (str): Name of the run.
        experiment (str): Name of the experiment.
        tags (dict): Tags to be attached for the run.
        tracking_uri (str): MLflow tracking server uri.

    This handler implements the helper functions to initialize the run and to
    log metrics and artifacts to the mlflow server.
    """

    def __init__(self, **kwargs: Any):
        self.mlflow = import_mlflow()
        if "DATABRICKS_RUNTIME_VERSION" in os.environ:
            self.mlflow.set_tracking_uri("databricks")
            self.mlf_expid = self.mlflow.tracking.fluent._get_experiment_id()
            self.mlf_exp = self.mlflow.get_experiment(self.mlf_expid)
        else:
            tracking_uri = get_from_dict_or_env(
                kwargs, "tracking_uri", "MLFLOW_TRACKING_URI", ""
            )
            self.mlflow.set_tracking_uri(tracking_uri)

            if run_id := kwargs.get("run_id"):
                self.mlf_expid = self.mlflow.get_run(run_id).info.experiment_id
            else:
                # User can set other env variables described here
                # > https://www.mlflow.org/docs/latest/tracking.html#logging-to-a-tracking-server
                experiment_name = get_from_dict_or_env(
                    kwargs, "experiment_name", "MLFLOW_EXPERIMENT_NAME"
                )
                self.mlf_exp = self.mlflow.get_experiment_by_name(experiment_name)
                if self.mlf_exp is not None:
                    self.mlf_expid = self.mlf_exp.experiment_id
                else:
                    self.mlf_expid = self.mlflow.create_experiment(experiment_name)

        self.start_run(
            kwargs["run_name"], kwargs["run_tags"], kwargs.get("run_id", None)
        )
        self.dir = kwargs.get("artifacts_dir", "")

    def start_run(
        self, name: str, tags: Dict[str, str], run_id: Optional[str] = None
    ) -> None:
        """
        If run_id is provided, the run with the given run_id is reused.
        Otherwise, a new run is started, with a random suffix auto-generated
        for the name.
        """
        if run_id is None:
            if name.endswith("-%"):
                rname = "".join(
                    random.choices(string.ascii_uppercase + string.digits, k=7)
                )
                name = name[:-1] + rname
            run = self.mlflow.MlflowClient().create_run(
                self.mlf_expid, run_name=name, tags=tags
            )
            run_id = run.info.run_id
        self.run_id = run_id

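    # For example (hypothetical values): start_run("langchainrun-%", {}) creates
    # a run named like "langchainrun-X7K2QP9", with the trailing "%" replaced by
    # a random 7-character suffix.
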
    def finish_run(self) -> None:
        """Finish the run."""
        self.mlflow.end_run()

    def metric(self, key: str, value: float) -> None:
        """Log a metric to the mlflow server."""
        self.mlflow.log_metric(key, value, run_id=self.run_id)

    def metrics(
        self, data: Union[Dict[str, float], Dict[str, int]], step: Optional[int] = 0
    ) -> None:
        """Log all metrics in the input dict."""
        self.mlflow.log_metrics(data, run_id=self.run_id)

    def jsonf(self, data: Dict[str, Any], filename: str) -> None:
        """Log the input data as a json file artifact."""
        self.mlflow.log_dict(
            data, os.path.join(self.dir, f"{filename}.json"), run_id=self.run_id
        )

    def table(self, name: str, dataframe: Any) -> None:
        """Log the input pandas dataframe as an html table."""
        self.html(dataframe.to_html(), f"table_{name}")

    def html(self, html: str, filename: str) -> None:
        """Log the input html string as an html file artifact."""
        self.mlflow.log_text(
            html, os.path.join(self.dir, f"{filename}.html"), run_id=self.run_id
        )

    def text(self, text: str, filename: str) -> None:
        """Log the input text as a text file artifact."""
        self.mlflow.log_text(
            text, os.path.join(self.dir, f"{filename}.txt"), run_id=self.run_id
        )

    def artifact(self, path: str) -> None:
        """Upload the file from the given path as an artifact."""
        self.mlflow.log_artifact(path, run_id=self.run_id)

    def langchain_artifact(self, chain: Any) -> None:
        """Log the given langchain model via mlflow.langchain."""
        self.mlflow.langchain.log_model(chain, "langchain-model", run_id=self.run_id)


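# A minimal usage sketch for MlflowLogger (illustrative; the tracking URI,
# experiment name, run name, and tags below are placeholder values):
#
#     mlflg = MlflowLogger(
#         tracking_uri="http://localhost:5000",
#         experiment_name="langchain",
#         run_name="demo-run",
#         run_tags={"team": "nlp"},
#     )
#     mlflg.metric("tokens", 42.0)
#     mlflg.text("hello world", "greeting")
#     mlflg.finish_run()

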
class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs metrics and artifacts to mlflow server.

    Parameters:
        name (str): Name of the run.
        experiment (str): Name of the experiment.
        tags (dict): Tags to be attached for the run.
        tracking_uri (str): MLflow tracking server uri.

    This handler utilizes each callback method that is called, formats its
    input with metadata regarding the state of the LLM run, and appends the
    response to the record lists (both the per-method {method}_records and
    action_records). The response is then logged to the mlflow server.
    """

    def __init__(
        self,
        name: Optional[str] = "langchainrun-%",
        experiment: Optional[str] = "langchain",
        tags: Optional[Dict] = None,
        tracking_uri: Optional[str] = None,
        run_id: Optional[str] = None,
        artifacts_dir: str = "",
    ) -> None:
        """Initialize callback handler."""
        import_pandas()
        import_mlflow()
        super().__init__()

        self.name = name
        self.experiment = experiment
        self.tags = tags or {}
        self.tracking_uri = tracking_uri
        self.run_id = run_id
        self.artifacts_dir = artifacts_dir

        self.temp_dir = tempfile.TemporaryDirectory()

        self.mlflg = MlflowLogger(
            tracking_uri=self.tracking_uri,
            experiment_name=self.experiment,
            run_name=self.name,
            run_tags=self.tags,
            run_id=self.run_id,
            artifacts_dir=self.artifacts_dir,
        )

        self.action_records: list = []
        self.nlp = None
        try:
            spacy = import_spacy()
        except ImportError as e:
            logger.warning(e.msg)
        else:
            try:
                self.nlp = spacy.load("en_core_web_sm")
            except OSError:
                logger.warning(
                    "Run `python -m spacy download en_core_web_sm` "
                    "to download en_core_web_sm model for text visualization."
                )

        try:
            self.textstat = import_textstat()
        except ImportError as e:
            logger.warning(e.msg)
            self.textstat = None

        self.metrics = {key: 0 for key in mlflow_callback_metrics()}

        self.records: Dict[str, Any] = {
            "on_llm_start_records": [],
            "on_llm_token_records": [],
            "on_llm_end_records": [],
            "on_chain_start_records": [],
            "on_chain_end_records": [],
            "on_tool_start_records": [],
            "on_tool_end_records": [],
            "on_text_records": [],
            "on_agent_finish_records": [],
            "on_agent_action_records": [],
            "on_retriever_start_records": [],
            "on_retriever_end_records": [],
            "action_records": [],
        }

    def _reset(self) -> None:
        for k, v in self.metrics.items():
            self.metrics[k] = 0
        for k, v in self.records.items():
            self.records[k] = []

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1

        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        for idx, prompt in enumerate(prompts):
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.records["on_llm_start_records"].append(prompt_resp)
            self.records["action_records"].append(prompt_resp)
            self.mlflg.jsonf(prompt_resp, f"llm_start_{llm_starts}_prompt_{idx}")

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1

        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_llm_token_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1

        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        for generations in response.generations:
            for idx, generation in enumerate(generations):
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(
                    analyze_text(
                        generation.text,
                        nlp=self.nlp,
                        textstat=self.textstat,
                    )
                )
                if "text_complexity_metrics" in generation_resp:
                    complexity_metrics: Dict[str, float] = generation_resp.pop(
                        "text_complexity_metrics"
                    )
                    self.mlflg.metrics(
                        complexity_metrics,
                        step=self.metrics["step"],
                    )
                self.records["on_llm_end_records"].append(generation_resp)
                self.records["action_records"].append(generation_resp)
                self.mlflg.jsonf(resp, f"llm_end_{llm_ends}_generation_{idx}")
                if "dependency_tree" in generation_resp:
                    dependency_tree = generation_resp["dependency_tree"]
                    self.mlflg.html(
                        dependency_tree, "dep-" + hash_string(generation.text)
                    )
                if "entities" in generation_resp:
                    entities = generation_resp["entities"]
                    self.mlflg.html(entities, "ent-" + hash_string(generation.text))

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1

        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        if isinstance(inputs, dict):
            chain_input = ",".join([f"{k}={v}" for k, v in inputs.items()])
        elif isinstance(inputs, list):
            chain_input = ",".join([str(input) for input in inputs])
        else:
            chain_input = str(inputs)
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input
        self.records["on_chain_start_records"].append(input_resp)
        self.records["action_records"].append(input_resp)
        self.mlflg.jsonf(input_resp, f"chain_start_{chain_starts}")

    def on_chain_end(
        self, outputs: Union[Dict[str, Any], str, List[str]], **kwargs: Any
    ) -> None:
        """Run when chain ends running."""
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1

        chain_ends = self.metrics["chain_ends"]

        resp: Dict[str, Any] = {}
        if isinstance(outputs, dict):
            chain_output = ",".join([f"{k}={v}" for k, v in outputs.items()])
        elif isinstance(outputs, list):
            chain_output = ",".join(map(str, outputs))
        else:
            chain_output = str(outputs)
        resp.update({"action": "on_chain_end", "outputs": chain_output})
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_chain_end_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_tool_start_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"tool_start_{tool_starts}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        self.metrics["step"] += 1
        self.metrics["tool_ends"] += 1
        self.metrics["ends"] += 1

        tool_ends = self.metrics["tool_ends"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_tool_end_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when text is received."""
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1

        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_text", "text": text})
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_text_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1

        agent_ends = self.metrics["agent_ends"]

        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_agent_finish_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {}
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_agent_action_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"agent_action_{tool_starts}")

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever starts running."""
        self.metrics["step"] += 1
        self.metrics["retriever_starts"] += 1
        self.metrics["starts"] += 1

        retriever_starts = self.metrics["retriever_starts"]

        resp: Dict[str, Any] = {}
        resp.update({"action": "on_retriever_start", "query": query})
        resp.update(flatten_dict(serialized))
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_retriever_start_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"retriever_start_{retriever_starts}")

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever ends running."""
        self.metrics["step"] += 1
        self.metrics["retriever_ends"] += 1
        self.metrics["ends"] += 1

        retriever_ends = self.metrics["retriever_ends"]

        resp: Dict[str, Any] = {}
        retriever_documents = [
            {
                "page_content": doc.page_content,
                "metadata": {
                    k: (
                        str(v)
                        if not isinstance(v, list)
                        else ",".join(str(x) for x in v)
                    )
                    for k, v in doc.metadata.items()
                },
            }
            for doc in documents
        ]
        resp.update({"action": "on_retriever_end", "documents": retriever_documents})
        resp.update(self.metrics)

        self.mlflg.metrics(self.metrics, step=self.metrics["step"])

        self.records["on_retriever_end_records"].append(resp)
        self.records["action_records"].append(resp)
        self.mlflg.jsonf(resp, f"retriever_end_{retriever_ends}")

    def on_retriever_error(self, error: BaseException, **kwargs: Any) -> Any:
        """Run when Retriever errors."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.records["on_llm_start_records"])
        on_llm_end_records_df = pd.DataFrame(self.records["on_llm_end_records"])

        llm_input_columns = ["step", "prompt"]
        if "name" in on_llm_start_records_df.columns:
            llm_input_columns.append("name")
        elif "id" in on_llm_start_records_df.columns:
            # id is llm class's full import path. For example:
            # ["langchain", "llms", "openai", "AzureOpenAI"]
            on_llm_start_records_df["name"] = on_llm_start_records_df["id"].apply(
                lambda id_: id_[-1]
            )
            llm_input_columns.append("name")
        llm_input_prompts_df = (
            on_llm_start_records_df[llm_input_columns]
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )
        complexity_metrics_columns = (
            get_text_complexity_metrics() if self.textstat is not None else []
        )
        visualizations_columns = (
            ["dependency_tree", "entities"] if self.nlp is not None else []
        )
        token_usage_columns = [
            "token_usage_total_tokens",
            "token_usage_prompt_tokens",
            "token_usage_completion_tokens",
        ]
        token_usage_columns = [
            x for x in token_usage_columns if x in on_llm_end_records_df.columns
        ]
        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                ]
                + token_usage_columns
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompt", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompt"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df

    def _contain_llm_records(self) -> bool:
        return bool(self.records["on_llm_start_records"])

    def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:
        """Log the collected records and the langchain asset to mlflow."""
        pd = import_pandas()
        self.mlflg.table(
            "action_records", pd.DataFrame(self.records["action_records"])
        )
        if self._contain_llm_records():
            session_analysis_df = self._create_session_analysis_df()
            chat_html = session_analysis_df.pop("chat_html")
            chat_html = chat_html.replace("\n", "", regex=True)
            self.mlflg.table("session_analysis", pd.DataFrame(session_analysis_df))
            self.mlflg.html("".join(chat_html.tolist()), "chat_html")

        if langchain_asset:
            # To avoid circular import error
            # mlflow only supports LLMChain asset
            if "langchain.chains.llm.LLMChain" in str(type(langchain_asset)):
                self.mlflg.langchain_artifact(langchain_asset)
            else:
                langchain_asset_path = str(Path(self.temp_dir.name, "model.json"))
                try:
                    langchain_asset.save(langchain_asset_path)
                    self.mlflg.artifact(langchain_asset_path)
                except ValueError:
                    try:
                        langchain_asset.save_agent(langchain_asset_path)
                        self.mlflg.artifact(langchain_asset_path)
                    except AttributeError:
                        print("Could not save model.")  # noqa: T201
                        traceback.print_exc()
                    except NotImplementedError:
                        print("Could not save model.")  # noqa: T201
                        traceback.print_exc()
                except NotImplementedError:
                    print("Could not save model.")  # noqa: T201
                    traceback.print_exc()
        if finish:
            self.mlflg.finish_run()
            self._reset()

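
# A minimal end-to-end sketch (illustrative; assumes a reachable MLflow
# tracking server and the separate langchain_openai package):
#
#     from langchain_community.callbacks import MlflowCallbackHandler
#     from langchain_openai import OpenAI
#
#     mlflow_callback = MlflowCallbackHandler(
#         tracking_uri="http://localhost:5000", experiment="langchain"
#     )
#     llm = OpenAI(temperature=0, callbacks=[mlflow_callback])
#     llm.invoke("Tell me a joke")
#     mlflow_callback.flush_tracker(llm, finish=True)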