from copy import deepcopy
from typing import Any, Dict, List, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import
def import_aim() -> Any:
    """Import and return the ``aim`` package.

    The import is delegated to ``guard_import``, which raises an error when
    the package is not installed.
    """
    aim_module = guard_import("aim")
    return aim_module
class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback handler that logs to Aim.

    Parameters:
        repo (:obj:`str`, optional): Aim repository path or Repo object to
            which the Run object is bound. If skipped, the default Repo is used.
        experiment_name (:obj:`str`, optional): Sets the Run's ``experiment``
            property. ``'default'`` if not specified. Can be used later to
            query runs/sequences.
        system_tracking_interval (:obj:`int`, optional): Tracking interval in
            seconds for system usage metrics (CPU, memory, etc.). Set to
            ``None`` to disable system metrics tracking.
        log_system_params (:obj:`bool`, optional): Enable/disable logging of
            system params such as installed packages, git info, environment
            variables, etc.

    Each callback method increments run-state counters (``step``, ``starts``,
    ``ends``, ...; these come from ``BaseMetadataCallbackHandler``, defined
    elsewhere). It then uses the metadata returned by
    ``get_custom_callback_meta()`` as the Aim tracking context and records
    the relevant text on the Aim run with ``Run.track``.
    """

    def __init__(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
    ) -> None:
        """Initialize the callback handler and open a new Aim run."""
        super().__init__()

        aim = import_aim()
        self.repo = repo
        self.experiment_name = experiment_name
        self.system_tracking_interval = system_tracking_interval
        self.log_system_params = log_system_params
        self._run = aim.Run(
            repo=self.repo,
            experiment=self.experiment_name,
            system_tracking_interval=self.system_tracking_interval,
            log_system_params=self.log_system_params,
        )
        # Remember the hash so ``setup`` can reopen this same run later.
        self._run_hash = self._run.hash
        self.action_records: list = []

    def setup(self, **kwargs: Any) -> None:
        """(Re)open the Aim run if needed and store ``kwargs`` on it.

        If ``self._run`` is falsy, the run identified by ``self._run_hash`` is
        reopened. When no hash is known either, a brand-new run is created.
        Every keyword argument is then written with
        ``Run.set(key, value, strict=False)``.
        """
        aim = import_aim()

        if not self._run:
            if self._run_hash:
                self._run = aim.Run(
                    self._run_hash,
                    repo=self.repo,
                    system_tracking_interval=self.system_tracking_interval,
                )
            else:
                self._run = aim.Run(
                    repo=self.repo,
                    experiment=self.experiment_name,
                    system_tracking_interval=self.system_tracking_interval,
                    log_system_params=self.log_system_params,
                )
                self._run_hash = self._run.hash

        for key, value in kwargs.items():
            self._run.set(key, value, strict=False)

    @staticmethod
    def _text_for_key(data: Any, key: str) -> str:
        """Return ``data[key]`` as a string, falling back to ``str(data)``.

        Not every chain or agent uses the ``"input"``/``"output"`` keys, and
        the payload is not always a dict. Falling back to the whole payload
        keeps the callback from raising ``KeyError`` and still records
        something useful. The result is always a ``str`` so it can be wrapped
        in ``aim.Text``.
        """
        value = data[key] if isinstance(data, dict) and key in data else data
        return value if isinstance(value, str) else str(value)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts: track every prompt as an ``aim.Text``."""
        aim = import_aim()
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_llm_start"}
        resp.update(self.get_custom_callback_meta())

        # Prompts are only read here, so no defensive copy is needed.
        self._run.track(
            [aim.Text(prompt) for prompt in prompts],
            name="on_llm_start",
            context=resp,
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends: track the text of every generation."""
        aim = import_aim()
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_llm_end"}
        resp.update(self.get_custom_callback_meta())

        # ``response.generations`` is a list of lists of generations. It is
        # only read, so deep-copying the whole LLMResult is unnecessary.
        generated = [
            aim.Text(generation.text)
            for generations in response.generations
            for generation in generations
        ]
        self._run.track(
            generated,
            name="on_llm_end",
            context=resp,
        )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token (only counters are updated)."""
        self.step += 1
        self.llm_streams += 1

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors (only counters are updated)."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running: track the chain input as text.

        Uses ``inputs["input"]`` when present, otherwise the whole ``inputs``.
        """
        aim = import_aim()
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {"action": "on_chain_start"}
        resp.update(self.get_custom_callback_meta())

        self._run.track(
            aim.Text(self._text_for_key(inputs, "input")),
            name="on_chain_start",
            context=resp,
        )

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running: track the chain output as text.

        Uses ``outputs["output"]`` when present, otherwise the whole
        ``outputs``.
        """
        aim = import_aim()
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_chain_end"}
        resp.update(self.get_custom_callback_meta())

        self._run.track(
            aim.Text(self._text_for_key(outputs, "output")),
            name="on_chain_end",
            context=resp,
        )

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors (only counters are updated)."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text (only the step and text counters are updated)."""
        self.step += 1
        self.text_ctr += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent finishes: track its final output and log."""
        aim = import_aim()
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp: Dict[str, Any] = {"action": "on_agent_finish"}
        resp.update(self.get_custom_callback_meta())

        output = self._text_for_key(finish.return_values, "output")
        text = "OUTPUT:\n{}\n\nLOG:\n{}".format(output, finish.log)
        self._run.track(aim.Text(text), name="on_agent_finish", context=resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action: track the tool input and the agent log."""
        aim = import_aim()
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
        }
        resp.update(self.get_custom_callback_meta())

        text = "TOOL INPUT:\n{}\n\nLOG:\n{}".format(action.tool_input, action.log)
        self._run.track(aim.Text(text), name="on_agent_action", context=resp)

    def flush_tracker(
        self,
        repo: Optional[str] = None,
        experiment_name: Optional[str] = None,
        system_tracking_interval: Optional[int] = 10,
        log_system_params: bool = True,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            repo: Aim repository path or Repo object for the new Run. If
                omitted (or empty), the handler's current ``repo`` is reused.
            experiment_name: ``experiment`` property of the new Run. If
                omitted (or empty), the current ``experiment_name`` is reused.
            system_tracking_interval: Tracking interval in seconds for system
                usage metrics of the new Run. Pass ``None`` to disable system
                metrics tracking.
            log_system_params: Enable/disable logging of system params
                (installed packages, git info, environment variables, ...)
                for the new Run.
            langchain_asset: LangChain asset whose ``.dict()`` items are
                stored on the current run (best effort).
            reset: Close the current run and open a new one.
            finish: Close the current run without opening a new one.

        Returns:
            None
        """
        if langchain_asset:
            try:
                for key, value in langchain_asset.dict().items():
                    self._run.set(key, value, strict=False)
            except Exception:
                # Storing the asset is best effort (not every asset has a
                # serializable ``.dict()``). Record the failure rather than
                # discarding it silently, but do not abort the flush.
                import logging

                logging.getLogger(__name__).debug(
                    "Failed to store langchain_asset on the Aim run",
                    exc_info=True,
                )

        if finish or reset:
            self._run.close()
            self.reset_callback_meta()

        if reset:
            aim = import_aim()
            self.repo = repo or self.repo
            self.experiment_name = experiment_name or self.experiment_name
            # Take these values as given. A truthiness fallback would replace
            # an explicit ``None`` interval or ``False`` flag with the old
            # value, making it impossible to disable tracking here.
            self.system_tracking_interval = system_tracking_interval
            self.log_system_params = log_system_params
            self._run = aim.Run(
                repo=self.repo,
                experiment=self.experiment_name,
                system_tracking_interval=self.system_tracking_interval,
                log_system_params=self.log_system_params,
            )
            self._run_hash = self._run.hash
            self.action_records = []