Source code for langchain_community.callbacks.tracers.wandb

"""A tracer implementation that logs activity to Weights & Biases."""
from __future__ import annotations

import json
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
)

from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

if TYPE_CHECKING:
    from wandb import Settings as WBSettings
    from wandb.sdk.data_types.trace_tree import Span
    from wandb.sdk.lib.paths import StrPath
    from wandb.wandb_run import Run as WBRun


# Module-wide switch: when True, RunProcessor reports span-conversion and
# model-serialization failures through ``wandb.termwarn``; when False those
# failures are dropped silently (the affected method still returns None).
PRINT_WARNINGS = True


def _serialize_io(run_inputs: Optional[dict]) -> dict:
    if not run_inputs:
        return {}
    from google.protobuf.json_format import MessageToJson
    from google.protobuf.message import Message

    serialized_inputs = {}
    for key, value in run_inputs.items():
        if isinstance(value, Message):
            serialized_inputs[key] = MessageToJson(value)
        elif key == "input_documents":
            serialized_inputs.update(
                {f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
            )
        else:
            serialized_inputs[key] = value
    return serialized_inputs


class RunProcessor:
    """Convert LangChain Runs into W&B trace objects (Spans and model dicts)."""

    def __init__(self, wandb_module: Any, trace_module: Any):
        """Store the W&B modules used for warnings and span construction.

        :param wandb_module: The imported ``wandb`` module (used for ``termwarn``).
        :param trace_module: The W&B trace-tree module providing ``Span``,
            ``Result``, ``StatusCode`` and ``SpanKind``.
        """
        self.wandb = wandb_module
        self.trace_tree = trace_module

    def process_span(self, run: Run) -> Optional["Span"]:
        """Convert a LangChain Run into a W&B Trace Span.

        Conversion is best-effort: any exception is reported via
        ``wandb.termwarn`` (when ``PRINT_WARNINGS`` is set) and None is returned.

        :param run: The LangChain Run to convert.
        :return: The converted W&B Trace Span, or None if conversion failed.
        """
        try:
            return self._convert_lc_run_to_wb_span(run)
        except Exception as e:
            if PRINT_WARNINGS:
                self.wandb.termwarn(
                    f"Skipping trace saving - unable to safely convert LangChain Run "
                    f"into W&B Trace due to: {e}"
                )
            return None

    def _convert_run_to_wb_span(self, run: Run) -> "Span":
        """Build the base span (timing, status, attributes) shared by all run types.

        :param run: The run to convert.
        :return: The converted span.
        """
        attributes = {**run.extra} if run.extra else {}
        # Read defensively so a Run object without ``execution_order`` does not
        # abort the whole trace conversion.
        attributes["execution_order"] = getattr(run, "execution_order", None)
        return self.trace_tree.Span(
            span_id=str(run.id) if run.id is not None else None,
            name=run.name,
            start_time_ms=int(run.start_time.timestamp() * 1000),
            end_time_ms=int(run.end_time.timestamp() * 1000)
            if run.end_time is not None
            else None,
            status_code=self.trace_tree.StatusCode.SUCCESS
            if run.error is None
            else self.trace_tree.StatusCode.ERROR,
            status_message=run.error,
            attributes=attributes,
        )

    def _convert_llm_run_to_wb_span(self, run: Run) -> "Span":
        """Convert a LangChain LLM run into a W&B Trace Span.

        One ``Result`` is produced per input prompt (or message batch); its
        outputs map ``gen_<i>`` to the text of each generation for that input,
        or are None when no generations were recorded for it.

        :param run: The LangChain LLM run to convert.
        :return: The converted W&B Trace Span.
        """
        base_span = self._convert_run_to_wb_span(run)
        if base_span.attributes is None:
            base_span.attributes = {}
        outputs = run.outputs or {}
        base_span.attributes["llm_output"] = outputs.get("llm_output", {})

        inputs = run.inputs or {}
        # Indexing ``inputs["prompts"]`` directly raised KeyError for any "llm"
        # run recorded without prompts, which dropped the entire trace. Chat-model
        # runs presumably record ``messages`` instead (per langchain_core's
        # BaseTracer chat-model start hook — confirm for your version).
        if "prompts" in inputs:
            input_key, batches = "prompts", inputs["prompts"] or []
        else:
            input_key, batches = "messages", inputs.get("messages") or []
        generations = outputs.get("generations") or []

        base_span.results = [
            self.trace_tree.Result(
                inputs={input_key: prompt},
                outputs={
                    f"gen_{g_i}": gen.get("text")
                    for g_i, gen in enumerate(generations[ndx])
                }
                if ndx < len(generations) and generations[ndx]
                else None,
            )
            for ndx, prompt in enumerate(batches)
        ]
        base_span.span_kind = self.trace_tree.SpanKind.LLM
        return base_span

    def _convert_io_run_to_wb_span(self, run: Run) -> "Span":
        """Build a span carrying serialized inputs/outputs and converted children.

        :param run: The run to convert.
        :return: The span; the caller sets ``span_kind``.
        """
        base_span = self._convert_run_to_wb_span(run)
        base_span.results = [
            self.trace_tree.Result(
                inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
            )
        ]
        base_span.child_spans = [
            self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
        ]
        return base_span

    def _convert_chain_run_to_wb_span(self, run: Run) -> "Span":
        """Convert a LangChain Chain run into a W&B Trace Span.

        Runs whose name contains "agent" (case-insensitive) are tagged AGENT,
        all others CHAIN.

        :param run: The LangChain Chain run to convert.
        :return: The converted W&B Trace Span.
        """
        base_span = self._convert_io_run_to_wb_span(run)
        base_span.span_kind = (
            self.trace_tree.SpanKind.AGENT
            if "agent" in run.name.lower()
            else self.trace_tree.SpanKind.CHAIN
        )
        return base_span

    def _convert_tool_run_to_wb_span(self, run: Run) -> "Span":
        """Convert a LangChain Tool run into a W&B Trace Span.

        :param run: The LangChain Tool run to convert.
        :return: The converted W&B Trace Span.
        """
        base_span = self._convert_io_run_to_wb_span(run)
        base_span.span_kind = self.trace_tree.SpanKind.TOOL
        return base_span

    def _convert_lc_run_to_wb_span(self, run: Run) -> "Span":
        """Convert any LangChain Run into a W&B Trace Span, dispatching on run_type.

        :param run: The LangChain Run to convert.
        :return: The converted W&B Trace Span.
        """
        if run.run_type == "llm":
            return self._convert_llm_run_to_wb_span(run)
        elif run.run_type == "chain":
            return self._convert_chain_run_to_wb_span(run)
        elif run.run_type == "tool":
            return self._convert_tool_run_to_wb_span(run)
        else:
            return self._convert_run_to_wb_span(run)

    def process_model(self, run: Run) -> Optional[Dict[str, Any]]:
        """Serialize a run into the nested model dict passed to WBTraceTree.

        The run is round-tripped through JSON, flattened, reduced to a fixed set
        of keys, scrubbed of ``lc``/``type`` keys and any key containing
        ``api_key``, and rebuilt as a nested tree. Failures are reported via
        ``wandb.termwarn`` (when ``PRINT_WARNINGS`` is set) and yield None.

        :param run: The run to process.
        :return: The model_dict for WBTraceTree, or None on failure.
        """
        try:
            data = json.loads(run.json())
            processed = self.flatten_run(data)
            keep_keys = (
                "id",
                "name",
                "serialized",
                "inputs",
                "outputs",
                "parent_run_id",
                "execution_order",
            )
            processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)
            exact_keys, partial_keys = ("lc", "type"), ("api_key",)
            processed = self.modify_serialized_iterative(
                processed, exact_keys=exact_keys, partial_keys=partial_keys
            )
            output = self.build_tree(processed)
            return output
        except Exception as e:
            if PRINT_WARNINGS:
                self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
            return None

    def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten a nested run dict into a pre-order list of runs.

        Each run's ``child_runs`` key is popped (mutating the input) and its
        children follow it in the result.

        :param run: The root run dict to flatten.
        :return: The flattened list of runs.
        """

        def flatten(child_runs: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
            """Recursively flatten a list of runs (None is treated as empty)."""
            if child_runs is None:
                return []
            result = []
            for item in child_runs:
                grandchildren = item.pop("child_runs", [])
                result.append(item)
                result.extend(flatten(grandchildren))
            return result

        return flatten([run])

    def truncate_run_iterative(
        self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
    ) -> List[Dict[str, Any]]:
        """Reduce each run dict to the keys listed in ``keep_keys``.

        :param runs: The list of runs to truncate.
        :param keep_keys: The keys to keep in each run.
        :return: New dicts holding only the kept keys, in their original order.
        """
        return [{key: run[key] for key in run if key in keep_keys} for run in runs]

    def modify_serialized_iterative(
        self,
        runs: List[Dict[str, Any]],
        exact_keys: Tuple[str, ...] = (),
        partial_keys: Tuple[str, ...] = (),
    ) -> List[Dict[str, Any]]:
        """Reshape each run dict's ``serialized`` field for WBTraceTree.

        For every run:
        - nested dicts that carry an ``id``/``name`` get a ``_kind`` field
          (last element of a list ``id``, falling back to ``name``) and lose
          their ``id``/``name`` keys;
        - the contents of any nested ``kwargs`` dict are lifted one level up;
        - keys equal to one of ``exact_keys`` or containing one of
          ``partial_keys`` are removed at every depth;
        - the ``serialized`` dict is merged into the top level and the run is
          wrapped as ``{"<execution_order>_<name>": run}``.

        :param runs: The list of runs to modify.
        :param exact_keys: Keys to remove wherever they appear.
        :param partial_keys: Substrings; any key containing one is removed.
        :return: The list of modified runs.
        """

        def remove_exact_and_partial_keys(obj: Any) -> Any:
            """Recursively drop keys matching ``exact_keys`` / ``partial_keys``."""
            if isinstance(obj, dict):
                obj = {
                    k: v
                    for k, v in obj.items()
                    if k not in exact_keys
                    and not any(partial in k for partial in partial_keys)
                }
                for k, v in obj.items():
                    obj[k] = remove_exact_and_partial_keys(v)
            elif isinstance(obj, list):
                obj = [remove_exact_and_partial_keys(x) for x in obj]
            return obj

        def handle_id_and_kwargs(obj: Any, root: bool = False) -> Any:
            """Recursively replace ``id``/``name`` with ``_kind`` and lift ``kwargs``.

            :param obj: The (possibly nested) value to transform in place.
            :param root: True for the run dict itself, whose ``id``/``name``
                are kept.
            :return: The transformed value.
            """
            if isinstance(obj, dict):
                if ("id" in obj or "name" in obj) and not root:
                    _kind = obj.get("id")
                    if not _kind:
                        _kind = [obj.get("name")]
                    # A list id names the class in its last element; any other
                    # id (e.g. a plain string) is used whole rather than being
                    # reduced to its last character.
                    obj["_kind"] = (
                        _kind[-1] if isinstance(_kind, (list, tuple)) else _kind
                    )
                    obj.pop("id", None)
                    obj.pop("name", None)
                if "kwargs" in obj:
                    kwargs = obj.pop("kwargs") or {}
                    for k, v in kwargs.items():
                        obj[k] = v
                for k, v in obj.items():
                    obj[k] = handle_id_and_kwargs(v)
            elif isinstance(obj, list):
                obj = [handle_id_and_kwargs(x) for x in obj]
            return obj

        def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
            """Apply the ``_kind``/``kwargs`` rewrite, then the key removal."""
            serialized = handle_id_and_kwargs(serialized, root=True)
            serialized = remove_exact_and_partial_keys(serialized)
            return serialized

        def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
            """Transform one run dict into ``{"<execution_order>_<name>": run}``."""
            transformed_dict = transform_serialized(run)
            # ``serialized`` may be missing or None; treat both as empty rather
            # than crashing on ``.items()``.
            serialized = transformed_dict.pop("serialized", None) or {}
            for k, v in serialized.items():
                transformed_dict[k] = v
            _kind = transformed_dict.get("_kind", None)
            name = transformed_dict.pop("name", None)
            exec_ord = transformed_dict.pop("execution_order", None)
            if not name:
                name = _kind
            return {f"{exec_ord}_{name}": transformed_dict}

        return [transform_run(run) for run in runs]

    def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Nest a flat list of ``{label: run_data}`` dicts into one tree.

        Each run's ``id`` and ``parent_run_id`` are popped; a child's
        ``{label: data}`` entry is inserted into its parent's data. Children
        whose parent is not in ``runs`` are kept as root candidates instead of
        raising KeyError.

        :param runs: The list of runs to build the tree from.
        :return: The first root ``{label: data}`` dict, or ``{}`` if ``runs``
            is empty.
        """
        id_to_data: Dict[Any, Dict[str, Any]] = {}
        parent_links: Dict[Any, Any] = {}
        for entity in runs:
            for key, data in entity.items():
                id_val = data.pop("id", None)
                parent_run_id = data.pop("parent_run_id", None)
                id_to_data[id_val] = {key: data}
                if parent_run_id:
                    parent_links[id_val] = parent_run_id

        # Only link children whose parent is part of this batch.
        child_to_parent = {
            child_id: parent_id
            for child_id, parent_id in parent_links.items()
            if parent_id in id_to_data
        }
        for child_id, parent_id in child_to_parent.items():
            # Every entry in id_to_data holds exactly one {label: data} pair.
            parent_payload = next(iter(id_to_data[parent_id].values()))
            child_label, child_payload = next(iter(id_to_data[child_id].items()))
            parent_payload[child_label] = child_payload

        return next(
            (
                entry
                for id_val, entry in id_to_data.items()
                if id_val not in child_to_parent
            ),
            {},
        )
class WandbRunArgs(TypedDict, total=False):
    """Arguments for WandbTracer, forwarded to ``wandb.init()``.

    Every key is optional (``total=False``): callers pass only the arguments
    they want, and WandbTracer checks for keys such as ``settings`` with ``in``.
    """

    job_type: Optional[str]
    dir: Optional[StrPath]
    config: Union[Dict, str, None]
    project: Optional[str]
    entity: Optional[str]
    reinit: Optional[bool]
    tags: Optional[Sequence]
    group: Optional[str]
    name: Optional[str]
    notes: Optional[str]
    magic: Optional[Union[dict, str, bool]]
    config_exclude_keys: Optional[List[str]]
    config_include_keys: Optional[List[str]]
    anonymous: Optional[str]
    mode: Optional[str]
    allow_val_change: Optional[bool]
    resume: Optional[Union[bool, str]]
    force: Optional[bool]
    tensorboard: Optional[bool]
    sync_tensorboard: Optional[bool]
    monitor_gym: Optional[bool]
    save_code: Optional[bool]
    id: Optional[str]
    settings: Union[WBSettings, Dict[str, Any], None]
class WandbTracer(BaseTracer):
    """Callback handler that logs to Weights and Biases.

    This handler logs the model architecture and run traces to Weights and
    Biases, so that all LangChain activity is recorded in W&B.
    """

    _run: Optional[WBRun] = None
    _run_args: Optional[WandbRunArgs] = None

    def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
        """Initialize the WandbTracer.

        Parameters:
            run_args: (dict, optional) Arguments passed to ``wandb.init()``. If
                not provided, ``wandb.init()`` is called with only a default
                ``settings={"silent": True}``. Refer to ``wandb.init`` for details.

        Raises:
            ImportError: If the ``wandb`` package cannot be imported.

        To monitor all LangChain activity with W&B, add this tracer like any
        other LangChain callback:
        ```
        from wandb.integration.langchain import WandbTracer
        tracer = WandbTracer()
        chain = LLMChain(llm, callbacks=[tracer])
        # ...end of notebook / script:
        tracer.finish()
        ```
        """
        super().__init__(**kwargs)
        try:
            import wandb
            from wandb.sdk.data_types import trace_tree
        except ImportError as e:
            raise ImportError(
                "Could not import wandb python package. "
                "Please install it with `pip install -U wandb`."
            ) from e
        self._wandb = wandb
        self._trace_tree = trace_tree
        self._run_args = run_args
        # Only announce the run URL when this tracer is the one starting the run.
        self._ensure_run(should_print_url=(wandb.run is None))
        self.run_processor = RunProcessor(self._wandb, self._trace_tree)

    def finish(self) -> None:
        """Wait for all asynchronous processes to finish and data to upload.

        Proxy for ``wandb.finish()``.
        """
        self._wandb.finish()

    def _log_trace_from_run(self, run: Run) -> None:
        """Log a LangChain Run to W&B as a W&B Trace.

        Nothing is logged when the run cannot be converted into a span.
        """
        self._ensure_run()
        root_span = self.run_processor.process_span(run)
        if root_span is None:
            # The trace is discarded anyway; skip the model serialization work.
            return
        model_dict = self.run_processor.process_model(run)
        model_trace = self._trace_tree.WBTraceTree(
            root_span=root_span,
            model_dict=model_dict,
        )
        if self._wandb.run is not None:
            self._wandb.run.log({"langchain_trace": model_trace})

    def _ensure_run(self, should_print_url: bool = False) -> None:
        """Ensure there is an active W&B run.

        If ``wandb.run`` is None, a new run is started with ``wandb.init`` using
        the stored ``run_args``; when those omit ``settings``,
        ``{"silent": True}`` is supplied.

        :param should_print_url: If True, log the run URL and a beta notice.
        """
        if self._wandb.run is None:
            run_args: Dict = {**(self._run_args or {})}
            if "settings" not in run_args:
                run_args["settings"] = {"silent": True}
            self._wandb.init(**run_args)
        if self._wandb.run is not None:
            if should_print_url:
                run_url = self._wandb.run.settings.run_url
                self._wandb.termlog(
                    f"Streaming LangChain activity to W&B at {run_url}\n"
                    "`WandbTracer` is currently in beta.\n"
                    "Please report any issues to "
                    "https://github.com/wandb/wandb/issues with the tag "
                    "`langchain`."
                )
            self._wandb.run._label(repo="langchain")

    def _persist_run(self, run: "Run") -> None:
        """Persist a finished run by logging it to W&B as a trace."""
        self._log_trace_from_run(run)