Source code for langchain_community.callbacks.wandb_callback

import json
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.utils import guard_import

from langchain_community.callbacks.utils import (
    BaseMetadataCallbackHandler,
    flatten_dict,
    hash_string,
    import_pandas,
    import_spacy,
    import_textstat,
)


def import_wandb() -> Any:
    """Import the wandb python package and raise an error if it is not installed."""
    # guard_import performs the import by package name and returns the module.
    wandb_module = guard_import("wandb")
    return wandb_module
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load a JSON file into a dictionary.

    Parameters:
        json_path (Union[str, Path]): The path to the JSON file.

    Returns:
        (dict): The dictionary representation of the JSON file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, which would mis-decode non-ASCII content on some systems.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
def analyze_text(
    text: str,
    complexity_metrics: bool = True,
    visualize: bool = True,
    nlp: Any = None,
    output_dir: Optional[Union[str, Path]] = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        complexity_metrics (bool): Whether to compute complexity metrics.
        visualize (bool): Whether to visualize the text.
        nlp (spacy.lang): The spacy language model to use for visualization.
        output_dir (str): The directory to save the visualization files to.

    Returns:
        (dict): A dictionary containing the complexity metrics and visualization
            files serialized in a wandb.Html element.
    """
    resp = {}
    textstat = import_textstat()
    wandb = import_wandb()
    spacy = import_spacy()

    if complexity_metrics:
        # Each result key is also the name of the textstat function producing it.
        metric_names = (
            "flesch_reading_ease",
            "flesch_kincaid_grade",
            "smog_index",
            "coleman_liau_index",
            "automated_readability_index",
            "dale_chall_readability_score",
            "difficult_words",
            "linsear_write_formula",
            "gunning_fog",
            "text_standard",
            "fernandez_huerta",
            "szigriszt_pazos",
            "gutierrez_polini",
            "crawford",
            "gulpease_index",
            "osman",
        )
        resp.update({name: getattr(textstat, name)(text) for name in metric_names})

    # Visualization needs a loaded language model and somewhere to write files.
    if visualize and nlp and output_dir is not None:
        doc = nlp(text)
        for style, key in (("dep", "dependency_tree"), ("ent", "entities")):
            rendered = spacy.displacy.render(
                doc, style=style, jupyter=False, page=True
            )
            output_path = Path(output_dir, hash_string(f"{style}-{text}") + ".html")
            # write_text opens, writes and closes the file, so the content is
            # flushed to disk before wandb.Html is handed the path.
            output_path.write_text(rendered, encoding="utf-8")
            resp[key] = wandb.Html(str(output_path))

    return resp
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Construct an html element from a prompt and a generation.

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (wandb.Html): The html element.
    """
    # Local import keeps this block self-contained (stdlib only).
    from html import escape

    wandb = import_wandb()
    # Escape markup-significant characters first so text such as "a < b" or
    # "<script>" is displayed literally rather than interpreted as HTML; only
    # afterwards turn newlines into <br> tags (which must stay unescaped).
    formatted_prompt = escape(prompt, quote=False).replace("\n", "<br>")
    formatted_generation = escape(generation, quote=False).replace("\n", "<br>")
    markup = (
        f' <p style="color:black;">{formatted_prompt}:</p>'
        " <blockquote>"
        ' <p style="color:green;">'
        f" {formatted_generation}"
        " </p>"
        " </blockquote> "
    )
    return wandb.Html(markup, inject=False)
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Weights and Biases.

    Parameters:
        job_type (str): The type of job.
        project (str): The project to log to.
        entity (str): The entity to log to.
        tags (list): The tags to log.
        group (str): The group to log to.
        name (str): The name of the run.
        notes (str): The notes to log.
        visualize (bool): Whether to visualize the run.
        complexity_metrics (bool): Whether to log complexity metrics.
        stream_logs (bool): Whether to stream callback actions to W&B.

    This handler will utilize the associated callback method called and formats
    the input of each callback function with metadata regarding the state of LLM
    run, and adds the response to the list of records for both the
    {method}_records and action. It then logs the response using the run.log()
    method to Weights and Biases.
    """

    def __init__(
        self,
        job_type: Optional[str] = None,
        project: Optional[str] = "langchain_callback_demo",
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize callback handler."""
        # Resolve every optional dependency up front, before a run is started.
        wandb = import_wandb()
        import_pandas()
        import_textstat()
        spacy = import_spacy()
        super().__init__()

        self.job_type = job_type
        self.project = project
        self.entity = entity
        self.tags = tags
        self.group = group
        self.name = name
        self.notes = notes
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        # Scratch space for rendered visualizations and the saved model file.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.run = wandb.init(
            job_type=self.job_type,
            project=self.project,
            entity=self.entity,
            tags=self.tags,
            group=self.group,
            name=self.name,
            notes=self.notes,
        )
        warning = (
            "DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
            "of the `WandbTracer`. Please update your code to use the `WandbTracer` "
            "instead."
        )
        wandb.termwarn(
            warning,
            repeat=False,
        )
        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")

    def _init_resp(self) -> Dict:
        """Return a fresh record with every known callback column set to None."""
        return {k: None for k in self.callback_columns}

    def _record(self, records: list, resp: Dict) -> None:
        """Store ``resp`` in ``records`` and the action log; stream it if enabled."""
        records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        # One record per prompt, each sharing the same metadata snapshot.
        for prompt in prompts:
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self._record(self.on_llm_start_records, prompt_resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_llm_token_records, resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        # One record per generation, enriched with the text analysis results.
        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(
                    analyze_text(
                        generation.text,
                        complexity_metrics=self.complexity_metrics,
                        visualize=self.visualize,
                        nlp=self.nlp,
                        output_dir=self.temp_dir.name,
                    )
                )
                self._record(self.on_llm_end_records, generation_resp)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running.

        Raises:
            ValueError: If the chain input is missing or is neither a string
                nor a list.
        """
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        # Prefer "input" and fall back to "human_input"; a missing key yields
        # None and is rejected below instead of raising a bare KeyError.
        chain_input = inputs.get("input", inputs.get("human_input"))

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self._record(self.on_chain_start_records, input_resp)
        elif isinstance(chain_input, list):
            # Each element is merged into its own record via dict.update.
            for inp in chain_input:
                input_resp = deepcopy(resp)
                input_resp.update(inp)
                self._record(self.on_chain_start_records, input_resp)
        else:
            raise ValueError("Unexpected data format provided!")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        # Prefer "output" and fall back to "text" so chains that use the
        # latter key are recorded instead of raising KeyError.
        resp.update(
            {
                "action": "on_chain_end",
                "outputs": outputs.get("output", outputs.get("text")),
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_chain_end_records, resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_tool_start_records, resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        # Tool output may be any object; it is recorded as its string form.
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_tool_end_records, resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when the ``on_text`` callback fires; records the given text."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_text_records, resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_agent_finish_records, resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        # Agent actions are counted through the tool-start counter.
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self._record(self.on_agent_action_records, resp)

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = (
            on_llm_start_records_df[["step", "prompts", "name"]]
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )

        # Optional columns only exist when the matching analysis was enabled.
        complexity_metrics_columns = []
        visualizations_columns = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        if self.visualize:
            visualizations_columns = ["dependency_tree", "entities"]

        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                    "token_usage_total_tokens",
                    "token_usage_prompt_tokens",
                    "token_usage_completion_tokens",
                ]
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )

        # Prompts and outputs are aligned side by side on their row index.
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompts", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompts"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df

    def flush_tracker(
        self,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
        job_type: Optional[str] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: Optional[bool] = None,
        complexity_metrics: Optional[bool] = None,
    ) -> None:
        """Flush the tracker and reset the session.

        Parameters:
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.
            job_type: The job type.
            project: The project.
            entity: The entity.
            tags: The tags.
            group: The group.
            name: The name.
            notes: The notes.
            visualize: Whether to visualize. ``None`` keeps the current setting.
            complexity_metrics: Whether to compute complexity metrics. ``None``
                keeps the current setting.

        Returns:
            None
        """
        pd = import_pandas()
        wandb = import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_table = wandb.Table(
            dataframe=self._create_session_analysis_df()
        )
        self.run.log(
            {
                "action_records": action_records_table,
                "session_analysis": session_analysis_table,
            }
        )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            model_artifact = wandb.Artifact(name="model", type="model")
            model_artifact.add(action_records_table, name="action_records")
            model_artifact.add(session_analysis_table, name="session_analysis")
            try:
                langchain_asset.save(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except ValueError:
                # Retry with save_agent when the plain save path raises ValueError.
                langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                # Best effort: the tables are still logged without the model file.
                print("Could not save model.")  # noqa: T201
                print(repr(e))  # noqa: T201
            self.run.log_artifact(model_artifact)

        if finish or reset:
            self.run.finish()
            self.temp_dir.cleanup()
            self.reset_callback_meta()
        if reset:
            # Re-run __init__ to start a fresh run. Boolean overrides are tested
            # against None so an explicit False is honoured, and stream_logs is
            # forwarded so streaming is not silently switched off by a reset.
            self.__init__(  # type: ignore
                job_type=job_type if job_type else self.job_type,
                project=project if project else self.project,
                entity=entity if entity else self.entity,
                tags=tags if tags else self.tags,
                group=group if group else self.group,
                name=name if name else self.name,
                notes=notes if notes else self.notes,
                visualize=visualize if visualize is not None else self.visualize,
                complexity_metrics=(
                    complexity_metrics
                    if complexity_metrics is not None
                    else self.complexity_metrics
                ),
                stream_logs=self.stream_logs,
            )