Source code for langchain_community.callbacks.sagemaker_callback

import json
import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

from langchain_community.callbacks.utils import (
    flatten_dict,
)


def save_json(data: dict, file_path: str) -> None:
    """Serialize a dictionary as JSON into a local file.

    The target file is opened in write mode, so an existing file at
    ``file_path`` is overwritten.

    Args:
        data (dict): The dictionary to be saved.
        file_path (str): Local file path to write to.
    """
    with open(file_path, mode="w") as handle:
        json.dump(data, fp=handle)
class SageMakerCallbackHandler(BaseCallbackHandler):
    """Callback handler that logs prompt artifacts and metrics to SageMaker Experiments.

    Each callback bumps the relevant counters in ``self.metrics``. Most
    callbacks additionally write a JSON snapshot of the event (plus the current
    counters) into a temporary directory and hand that file to
    ``run.log_file``. The ``on_*_error`` callbacks only update counters.

    Args:
        run (sagemaker.experiments.run.Run): Run object the experiment is
            logged to.
    """

    def __init__(self, run: Any) -> None:
        """Initialize the callback handler.

        Args:
            run: Run object; only its ``log_file`` method is used by this class.
        """
        super().__init__()

        self.run = run

        # Event counters; a copy of this dict is embedded in every artifact.
        self.metrics: Dict[str, int] = {
            "step": 0,
            "starts": 0,
            "ends": 0,
            "errors": 0,
            "text_ctr": 0,
            "chain_starts": 0,
            "chain_ends": 0,
            "llm_starts": 0,
            "llm_ends": 0,
            "llm_streams": 0,
            "tool_starts": 0,
            "tool_ends": 0,
            "agent_ends": 0,
        }

        # Scratch directory where JSON artifacts are staged before logging.
        self.temp_dir = tempfile.mkdtemp()

    def _reset(self) -> None:
        """Set every counter in ``self.metrics`` back to zero."""
        for key in self.metrics:
            self.metrics[key] = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts; logs one artifact per prompt."""
        self.metrics["step"] += 1
        self.metrics["llm_starts"] += 1
        self.metrics["starts"] += 1

        llm_starts = self.metrics["llm_starts"]

        resp: Dict[str, Any] = {
            "action": "on_llm_start",
            **flatten_dict(serialized),
            **self.metrics,
        }

        for idx, prompt in enumerate(prompts):
            # Copy so each prompt gets its own independent payload.
            prompt_resp = deepcopy(resp)
            prompt_resp["prompt"] = prompt
            self.jsonf(
                prompt_resp,
                self.temp_dir,
                f"llm_start_{llm_starts}_prompt_{idx}",
            )

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.metrics["step"] += 1
        self.metrics["llm_streams"] += 1

        llm_streams = self.metrics["llm_streams"]

        resp: Dict[str, Any] = {
            "action": "on_llm_new_token",
            "token": token,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"llm_new_tokens_{llm_streams}")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running; logs one artifact per generation."""
        self.metrics["step"] += 1
        self.metrics["llm_ends"] += 1
        self.metrics["ends"] += 1

        llm_ends = self.metrics["llm_ends"]

        resp: Dict[str, Any] = {
            "action": "on_llm_end",
            **flatten_dict(response.llm_output or {}),
            **self.metrics,
        }

        # NOTE(review): ``idx`` restarts for every inner list, so when
        # ``response.generations`` holds more than one list, artifacts with the
        # same index share a file name. Names are kept as-is for compatibility.
        for generations in response.generations:
            for idx, generation in enumerate(generations):
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                # Log the per-generation payload. Passing the shared ``resp``
                # here would drop the generation data from the artifact.
                self.jsonf(
                    generation_resp,
                    self.temp_dir,
                    f"llm_end_{llm_ends}_generation_{idx}",
                )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors; only counters are updated."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.metrics["step"] += 1
        self.metrics["chain_starts"] += 1
        self.metrics["starts"] += 1

        chain_starts = self.metrics["chain_starts"]

        resp: Dict[str, Any] = {
            "action": "on_chain_start",
            **flatten_dict(serialized),
            **self.metrics,
        }

        # Inputs are stored as a single "key=value,key=value" string.
        chain_input = ",".join(f"{k}={v}" for k, v in inputs.items())
        input_resp = deepcopy(resp)
        input_resp["inputs"] = chain_input
        self.jsonf(input_resp, self.temp_dir, f"chain_start_{chain_starts}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.metrics["step"] += 1
        self.metrics["chain_ends"] += 1
        self.metrics["ends"] += 1

        chain_ends = self.metrics["chain_ends"]

        # Outputs are stored as a single "key=value,key=value" string.
        chain_output = ",".join(f"{k}={v}" for k, v in outputs.items())
        resp: Dict[str, Any] = {
            "action": "on_chain_end",
            "outputs": chain_output,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors; only counters are updated."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {
            "action": "on_tool_start",
            "input_str": input_str,
            **flatten_dict(serialized),
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"tool_start_{tool_starts}")

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running; the output is logged as ``str(output)``."""
        output = str(output)
        self.metrics["step"] += 1
        self.metrics["tool_ends"] += 1
        self.metrics["ends"] += 1

        tool_ends = self.metrics["tool_ends"]

        resp: Dict[str, Any] = {
            "action": "on_tool_end",
            "output": output,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors; only counters are updated."""
        self.metrics["step"] += 1
        self.metrics["errors"] += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when a text event is received; logs the text as an artifact."""
        self.metrics["step"] += 1
        self.metrics["text_ctr"] += 1

        text_ctr = self.metrics["text_ctr"]

        resp: Dict[str, Any] = {
            "action": "on_text",
            "text": text,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"on_text_{text_ctr}")

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running.

        Raises:
            KeyError: If ``finish.return_values`` has no ``"output"`` entry.
        """
        self.metrics["step"] += 1
        self.metrics["agent_ends"] += 1
        self.metrics["ends"] += 1

        agent_ends = self.metrics["agent_ends"]

        resp: Dict[str, Any] = {
            "action": "on_agent_finish",
            "output": finish.return_values["output"],
            "log": finish.log,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"agent_finish_{agent_ends}")

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action.

        Agent actions share the ``tool_starts`` counter with ``on_tool_start``;
        that counter also numbers the ``agent_action_*`` artifact.
        """
        self.metrics["step"] += 1
        self.metrics["tool_starts"] += 1
        self.metrics["starts"] += 1

        tool_starts = self.metrics["tool_starts"]

        resp: Dict[str, Any] = {
            "action": "on_agent_action",
            "tool": action.tool,
            "tool_input": action.tool_input,
            "log": action.log,
            **self.metrics,
        }

        self.jsonf(resp, self.temp_dir, f"agent_action_{tool_starts}")

    def jsonf(
        self,
        data: Dict[str, Any],
        data_dir: str,
        filename: str,
        is_output: Optional[bool] = True,
    ) -> None:
        """Log the input data as a JSON file artifact.

        Args:
            data: Payload to serialize.
            data_dir: Directory in which the JSON file is written.
            filename: Artifact name; ``.json`` is appended for the file on disk.
            is_output: Forwarded unchanged to ``run.log_file``.
        """
        file_path = os.path.join(data_dir, f"{filename}.json")
        save_json(data, file_path)
        self.run.log_file(file_path, name=filename, is_output=is_output)

    def flush_tracker(self) -> None:
        """Reset the counters and delete the temporary local directory.

        The directory is not recreated here, so callbacks that write artifacts
        into ``self.temp_dir`` are not expected to work after this call.
        """
        self._reset()
        shutil.rmtree(self.temp_dir)