Source code for langchain_community.callbacks.aim_callback

from copy import deepcopy
from typing import Any, Dict, List, Optional

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import


def import_aim() -> Any:
    """Import the ``aim`` Python package and return the module.

    The import is delegated to ``guard_import``, so an error is raised when
    the package is not installed.
    """
    aim_module = guard_import("aim")
    return aim_module
class BaseMetadataCallbackHandler:
    """Callback handler that holds metadata and state counters for callbacks.

    Attributes:
        step (int): The current step.
        starts (int): Number of times a start method has been called.
        ends (int): Number of times an end method has been called.
        errors (int): Number of times an error method has been called.
        text_ctr (int): Number of times the text method has been called.
        ignore_llm_ (bool): Whether to ignore LLM callbacks.
        ignore_chain_ (bool): Whether to ignore chain callbacks.
        ignore_agent_ (bool): Whether to ignore agent callbacks.
        ignore_retriever_ (bool): Whether to ignore retriever callbacks.
        always_verbose_ (bool): Whether to always be verbose.
        chain_starts (int): Number of times the chain start method has been
            called.
        chain_ends (int): Number of times the chain end method has been
            called.
        llm_starts (int): Number of times the LLM start method has been
            called.
        llm_ends (int): Number of times the LLM end method has been called.
        llm_streams (int): Number of LLM streaming (new token) events.
        tool_starts (int): Number of times the tool start method has been
            called.
        tool_ends (int): Number of times the tool end method has been called.
        agent_ends (int): Number of times the agent end method has been
            called.
    """

    def __init__(self) -> None:
        """Start with every counter at zero and every flag set to ``False``."""
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_

    def get_custom_callback_meta(self) -> Dict[str, Any]:
        """Return a snapshot of the step and all event counters.

        Returns:
            A new dictionary mapping each counter name to its current value.
            The ``ignore_*`` and ``always_verbose_`` flags are not included.
        """
        return {
            "step": self.step,
            "starts": self.starts,
            "ends": self.ends,
            "errors": self.errors,
            "text_ctr": self.text_ctr,
            "chain_starts": self.chain_starts,
            "chain_ends": self.chain_ends,
            "llm_starts": self.llm_starts,
            "llm_ends": self.llm_ends,
            "llm_streams": self.llm_streams,
            "tool_starts": self.tool_starts,
            "tool_ends": self.tool_ends,
            "agent_ends": self.agent_ends,
        }

    def reset_callback_meta(self) -> None:
        """Reset the callback metadata to the state set by ``__init__``."""
        # Keep this list in sync with __init__: every attribute initialised
        # there must be restored here, including ignore_retriever_.
        self.step = 0

        self.starts = 0
        self.ends = 0
        self.errors = 0
        self.text_ctr = 0

        self.ignore_llm_ = False
        self.ignore_chain_ = False
        self.ignore_agent_ = False
        self.ignore_retriever_ = False
        self.always_verbose_ = False

        self.chain_starts = 0
        self.chain_ends = 0

        self.llm_starts = 0
        self.llm_ends = 0
        self.llm_streams = 0

        self.tool_starts = 0
        self.tool_ends = 0

        self.agent_ends = 0
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that logs to Aim.

    Parameters:
        repo (:obj:`str`, optional): Aim repository path or Repo object to
            which the Run object is bound. If skipped, the default Repo is
            used.
        experiment_name (:obj:`str`, optional): Sets the Run's `experiment`
            property. 'default' if not specified. Can be used later to query
            runs/sequences.
        system_tracking_interval (:obj:`int`, optional): Sets the tracking
            interval in seconds for system usage metrics (CPU, memory, etc.).
            Set to `None` to disable system metrics tracking.
        log_system_params (:obj:`bool`, optional): Enable/disable logging of
            system params such as installed packages, git info, environment
            variables, etc.

    This handler uses the associated callback methods, attaches metadata about
    the state of the LLM run to the input of each callback, and logs the
    result to Aim.
    """

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
    ) -> None:
        """Initialize the callback handler and open an Aim run."""
        super().__init__()

        aim = import_aim()
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval
        self.log_system_params = log_system_params
        self._run = aim.Run(
            repo=self.repo,
            experiment=self.experiment_name,
            system_tracking_interval=self.system_tracking_interval,
            log_system_params=self.log_system_params,
        )
        self._run_hash = self._run.hash
        self.action_records: list = []

    def _track_event(self, action: str, value: Any, **extra: Any) -> None:
        """Track ``value`` on the run under ``action`` with the counter context.

        The context is ``{"action": action, **extra}`` merged with the current
        result of ``get_custom_callback_meta()``, so callers must update their
        counters before calling this helper.
        """
        context: Dict[str, Any] = {"action": action}
        context.update(extra)
        context.update(self.get_custom_callback_meta())
        self._run.track(value, name=action, context=context)

    @staticmethod
    def _payload_text(payload: Any, key: str) -> str:
        """Return ``payload[key]`` as text, or the whole payload as text.

        Chains are not required to name their input/output keys ``"input"``
        and ``"output"``; falling back to the full payload avoids a
        ``KeyError`` (or ``TypeError`` for non-mapping payloads) in that case.
        """
        if isinstance(payload, dict) and key in payload:
            return str(payload[key])
        return str(payload)

    def setup(self, **kwargs: Any) -> None:
        """Make sure a run is available and store ``kwargs`` on it.

        When ``self._run`` is falsy, the run is re-opened from the stored run
        hash if there is one; otherwise a new run is created from the
        handler's configuration and its hash is remembered. Each keyword
        argument is then written to the run with ``strict=False``.
        """
        aim = import_aim()

        if not self._run:
            if self._run_hash:
                self._run = aim.Run(
                    self._run_hash,
                    repo=self.repo,
                    system_tracking_interval=self.system_tracking_interval,
                )
            else:
                self._run = aim.Run(
                    repo=self.repo,
                    experiment=self.experiment_name,
                    system_tracking_interval=self.system_tracking_interval,
                    log_system_params=self.log_system_params,
                )
                self._run_hash = self._run.hash

        for key, value in kwargs.items():
            self._run.set(key, value, strict=False)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts; tracks every prompt as an Aim text."""
        aim = import_aim()

        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        # Prompts are only read, so no defensive copy is needed.
        self._track_event("on_llm_start", [aim.Text(prompt) for prompt in prompts])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running; tracks the text of every generation."""
        aim = import_aim()

        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        generated = [
            aim.Text(generation.text)
            for generations in response.generations
            for generation in generations
        ]
        self._track_event("on_llm_end", generated)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token; only counters are updated."""
        self.step += 1
        self.llm_streams += 1

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors; only counters are updated."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Tracks ``inputs["input"]`` when that key exists, otherwise the text
        of the whole ``inputs`` payload.
        """
        aim = import_aim()

        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        self._track_event(
            "on_chain_start", aim.Text(self._payload_text(inputs, "input"))
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running.

        Tracks ``outputs["output"]`` when that key exists, otherwise the text
        of the whole ``outputs`` payload.
        """
        aim = import_aim()

        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        self._track_event(
            "on_chain_end", aim.Text(self._payload_text(outputs, "output"))
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors; only counters are updated."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running; tracks the tool input string."""
        aim = import_aim()

        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        self._track_event("on_tool_start", aim.Text(input_str))

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running; tracks ``str(output)``."""
        output = str(output)
        aim = import_aim()

        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        self._track_event("on_tool_end", aim.Text(output))

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors; only counters are updated."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on a text event; only counters are updated."""
        self.step += 1
        self.text_ctr += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Tracks the ``"output"`` return value (or all return values when that
        key is absent) together with the agent log.
        """
        aim = import_aim()

        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        text = "OUTPUT:\n{}\n\nLOG:\n{}".format(
            self._payload_text(finish.return_values, "output"), finish.log
        )
        self._track_event("on_agent_finish", aim.Text(text))

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action; tracks the tool input and the agent log."""
        aim = import_aim()

        self.step += 1
        # An agent action is counted as a tool start.
        self.tool_starts += 1
        self.starts += 1

        text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(action.tool_input, action.log)
        self._track_event("on_agent_action", aim.Text(text), tool=action.tool)

    def flush_tracker(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            repo (:obj:`str`, optional): Aim repository path or Repo object to
                which the Run object is bound. If skipped, the default Repo
                is used.
            experiment_name (:obj:`str`, optional): Sets the Run's
                `experiment` property. 'default' if not specified. Can be
                used later to query runs/sequences.
            system_tracking_interval (:obj:`int`, optional): Sets the tracking
                interval in seconds for system usage metrics (CPU, memory,
                etc.). Set to `None` to disable system metrics tracking.
            log_system_params (:obj:`bool`, optional): Enable/disable logging
                of system params such as installed packages, git info,
                environment variables, etc.
            langchain_asset: The langchain asset to save. Its ``dict()``
                items are written to the current run on a best-effort basis.
            reset: Whether to reset the session, i.e. close the current run
                and open a new one.
            finish: Whether to finish (close) the run.

        Returns:
            None

        Note:
            When ``reset`` is true, a falsy argument keeps the handler's
            current setting. Because the defaults for
            ``system_tracking_interval`` and ``log_system_params`` are truthy
            (``10`` and ``True``), calling with defaults replaces the stored
            values with those defaults.
        """
        if langchain_asset:
            try:
                for key, value in langchain_asset.dict().items():
                    self._run.set(key, value, strict=False)
            except Exception as exc:
                # Saving the asset stays best-effort, but the failure is
                # surfaced instead of being silently discarded.
                import warnings

                warnings.warn(
                    f"Could not save langchain_asset to the Aim run: {exc!r}",
                    stacklevel=2,
                )

        if finish or reset:
            self._run.close()
            self.reset_callback_meta()

        if reset:
            aim = import_aim()

            self.repo = repo or self.repo
            self.experiment_name = experiment_name or self.experiment_name
            self.system_tracking_interval = (
                system_tracking_interval or self.system_tracking_interval
            )
            self.log_system_params = log_system_params or self.log_system_params

            self._run = aim.Run(
                repo=self.repo,
                experiment=self.experiment_name,
                system_tracking_interval=self.system_tracking_interval,
                log_system_params=self.log_system_params,
            )
            self._run_hash = self._run.hash
            self.action_records = []