"""一个记录活动到Weights & Biases的跟踪器实现。"""
from __future__ import annotations
import json
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
if TYPE_CHECKING:
    # Imported for static type checking only, so that importing this module
    # does not import ``wandb`` at runtime.
    from wandb import Settings as WBSettings
    from wandb.sdk.data_types.trace_tree import Span
    from wandb.sdk.lib.paths import StrPath
    from wandb.wandb_run import Run as WBRun

# Module-wide switch: when true, conversion/serialization failures in this
# module are reported through ``wandb.termwarn`` instead of being dropped
# silently.
PRINT_WARNINGS = True
def _serialize_io(run_inputs: Optional[dict]) -> dict:
if not run_inputs:
return {}
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
serialized_inputs = {}
for key, value in run_inputs.items():
if isinstance(value, Message):
serialized_inputs[key] = MessageToJson(value)
elif key == "input_documents":
serialized_inputs.update(
{f"input_document_{i}": doc.json() for i, doc in enumerate(value)}
)
else:
serialized_inputs[key] = value
return serialized_inputs
class RunProcessor:
    """Convert LangChain runs into the pieces needed to build a ``WBTraceTree``.

    Two independent views of a run are produced: a tree of W&B trace spans
    (``process_span``) and a nested "model dict" built from the run's JSON
    form (``process_model``).
    """

    def __init__(self, wandb_module: Any, trace_module: Any):
        """Store the modules this processor delegates to.

        :param wandb_module: The imported ``wandb`` module (used for
            ``termwarn``).
        :param trace_module: The module providing ``Span``, ``Result``,
            ``StatusCode`` and ``SpanKind``.
        """
        self.wandb = wandb_module
        self.trace_tree = trace_module

    def process_span(self, run: "Run") -> Optional["Span"]:
        """Convert a LangChain Run into a W&B Trace Span.

        Any exception raised during conversion is caught; a warning is
        printed when ``PRINT_WARNINGS`` is set.

        :param run: The LangChain Run to convert.
        :return: The converted W&B Trace Span, or ``None`` if conversion failed.
        """
        try:
            return self._convert_lc_run_to_wb_span(run)
        except Exception as e:
            if PRINT_WARNINGS:
                self.wandb.termwarn(
                    f"Skipping trace saving - unable to safely convert LangChain Run "
                    f"into W&B Trace due to: {e}"
                )
            return None

    def _convert_run_to_wb_span(self, run: "Run") -> "Span":
        """Create the base span shared by every run type.

        :param run: The run to convert.
        :return: A span carrying the run's id, name, timing, status and
            attributes (``run.extra`` plus ``execution_order``).
        """
        attributes = {**run.extra} if run.extra else {}
        attributes["execution_order"] = run.execution_order

        return self.trace_tree.Span(
            span_id=str(run.id) if run.id is not None else None,
            name=run.name,
            # Timestamps are passed as integer milliseconds.
            start_time_ms=int(run.start_time.timestamp() * 1000),
            end_time_ms=int(run.end_time.timestamp() * 1000)
            if run.end_time is not None
            else None,
            status_code=self.trace_tree.StatusCode.SUCCESS
            if run.error is None
            else self.trace_tree.StatusCode.ERROR,
            status_message=run.error,
            attributes=attributes,
        )

    def _convert_llm_run_to_wb_span(self, run: "Run") -> "Span":
        """Convert a LangChain LLM run into a W&B Trace Span.

        One result is emitted per prompt; its outputs map ``gen_<i>`` to the
        text of the i-th generation for that prompt, or are ``None`` when no
        generation is available.

        :param run: The LangChain LLM run to convert.
        :return: The converted W&B Trace Span.
        """
        base_span = self._convert_run_to_wb_span(run)
        if base_span.attributes is None:
            base_span.attributes = {}

        outputs = run.outputs or {}
        base_span.attributes["llm_output"] = outputs.get("llm_output", {})

        # Tolerant lookups: a run without a "prompts" input or without a
        # "generations" output must not raise KeyError, which would make
        # process_span drop the entire trace.
        prompts = (run.inputs or {}).get("prompts") or []
        generations = outputs.get("generations") or []

        results = []
        for ndx, prompt in enumerate(prompts):
            prompt_generations = generations[ndx] if ndx < len(generations) else None
            results.append(
                self.trace_tree.Result(
                    inputs={"prompt": prompt},
                    outputs={
                        f"gen_{g_i}": gen["text"]
                        for g_i, gen in enumerate(prompt_generations)
                    }
                    if prompt_generations
                    else None,
                )
            )
        base_span.results = results
        base_span.span_kind = self.trace_tree.SpanKind.LLM

        return base_span

    def _convert_chain_run_to_wb_span(self, run: "Run") -> "Span":
        """Convert a LangChain Chain run into a W&B Trace Span.

        :param run: The LangChain Chain run to convert.
        :return: The converted W&B Trace Span, including converted child runs.
        """
        base_span = self._convert_run_to_wb_span(run)

        base_span.results = [
            self.trace_tree.Result(
                inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
            )
        ]
        base_span.child_spans = [
            self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
        ]
        # A chain whose name contains "agent" (case-insensitive) is shown as
        # an agent span.
        base_span.span_kind = (
            self.trace_tree.SpanKind.AGENT
            if "agent" in run.name.lower()
            else self.trace_tree.SpanKind.CHAIN
        )

        return base_span

    def _convert_tool_run_to_wb_span(self, run: "Run") -> "Span":
        """Convert a LangChain Tool run into a W&B Trace Span.

        :param run: The LangChain Tool run to convert.
        :return: The converted W&B Trace Span, including converted child runs.
        """
        base_span = self._convert_run_to_wb_span(run)
        base_span.results = [
            self.trace_tree.Result(
                inputs=_serialize_io(run.inputs), outputs=_serialize_io(run.outputs)
            )
        ]
        base_span.child_spans = [
            self._convert_lc_run_to_wb_span(child_run) for child_run in run.child_runs
        ]
        base_span.span_kind = self.trace_tree.SpanKind.TOOL

        return base_span

    def _convert_lc_run_to_wb_span(self, run: "Run") -> "Span":
        """Dispatch any LangChain Run to the converter for its ``run_type``.

        :param run: The LangChain Run to convert.
        :return: The converted W&B Trace Span.
        """
        if run.run_type == "llm":
            return self._convert_llm_run_to_wb_span(run)
        elif run.run_type == "chain":
            return self._convert_chain_run_to_wb_span(run)
        elif run.run_type == "tool":
            return self._convert_tool_run_to_wb_span(run)
        else:
            # Unknown run types still get the generic base span.
            return self._convert_run_to_wb_span(run)

    def process_model(self, run: "Run") -> Optional[Dict[str, Any]]:
        """Build the model dict for a run.

        The run is dumped to JSON, flattened, truncated to the relevant keys,
        cleaned up and finally rebuilt as a nested dictionary. Any exception
        is caught; a warning is printed when ``PRINT_WARNINGS`` is set.

        :param run: The run to process.
        :return: The ``model_dict`` to pass to ``WBTraceTree``, or ``None`` if
            processing failed.
        """
        try:
            data = json.loads(run.json())
            processed = self.flatten_run(data)

            keep_keys = (
                "id",
                "name",
                "serialized",
                "inputs",
                "outputs",
                "parent_run_id",
                "execution_order",
            )
            processed = self.truncate_run_iterative(processed, keep_keys=keep_keys)

            # "lc"/"type" keys are dropped, as is any key containing "api_key".
            exact_keys, partial_keys = ("lc", "type"), ("api_key",)
            processed = self.modify_serialized_iterative(
                processed, exact_keys=exact_keys, partial_keys=partial_keys
            )
            output = self.build_tree(processed)
            return output
        except Exception as e:
            if PRINT_WARNINGS:
                self.wandb.termwarn(f"WARNING: Failed to serialize model: {e}")
            return None

    def flatten_run(self, run: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten a nested run dictionary into a list of runs.

        The ``"child_runs"`` entry is removed from every visited dictionary
        (the input is mutated); parents precede their children in the result.

        :param run: The root run dictionary to flatten.
        :return: The flattened list of run dictionaries.
        """

        def flatten(child_runs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
            """Recursively flatten a list of run dictionaries.

            :param child_runs: The runs to flatten (``None`` is treated as empty).
            :return: The flattened list of runs.
            """
            if child_runs is None:
                return []

            result = []
            for item in child_runs:
                grandchildren = item.pop("child_runs", [])
                result.append(item)
                result.extend(flatten(grandchildren))

            return result

        return flatten([run])

    def truncate_run_iterative(
        self, runs: List[Dict[str, Any]], keep_keys: Tuple[str, ...] = ()
    ) -> List[Dict[str, Any]]:
        """Reduce every run dictionary to the specified keys.

        :param runs: The list of run dictionaries to truncate.
        :param keep_keys: The keys to keep in each run.
        :return: A list of new, truncated run dictionaries.
        """

        def truncate_single(run: Dict[str, Any]) -> Dict[str, Any]:
            """Return a copy of ``run`` restricted to ``keep_keys``.

            :param run: The run dictionary to truncate.
            :return: The truncated run dictionary.
            """
            return {key: value for key, value in run.items() if key in keep_keys}

        return list(map(truncate_single, runs))

    def modify_serialized_iterative(
        self,
        runs: List[Dict[str, Any]],
        exact_keys: Tuple[str, ...] = (),
        partial_keys: Tuple[str, ...] = (),
    ) -> List[Dict[str, Any]]:
        """Rewrite the run dictionaries into the layout used for the model dict.

        Removes any key that equals one of ``exact_keys`` or contains one of
        ``partial_keys``. Moves entries under a ``kwargs`` key up one level.
        Replaces the ``id``/``name`` of nested dictionaries with a ``_kind``
        string and promotes the ``serialized`` field to the top level.

        :param runs: The list of run dictionaries to modify (mutated in place).
        :param exact_keys: Keys to remove when the name matches exactly.
        :param partial_keys: Substrings; any key containing one is removed.
        :return: One single-entry dictionary per run, keyed
            ``"<execution_order>_<name>"``.
        """

        def remove_exact_and_partial_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
            """Recursively drop the exact and partial keys from ``obj``.

            :param obj: The dictionary (or list, or leaf value) to clean.
            :return: The cleaned object.
            """
            if isinstance(obj, dict):
                obj = {
                    k: v
                    for k, v in obj.items()
                    if k not in exact_keys
                    and not any(partial in k for partial in partial_keys)
                }
                for k, v in obj.items():
                    obj[k] = remove_exact_and_partial_keys(v)
            elif isinstance(obj, list):
                obj = [remove_exact_and_partial_keys(x) for x in obj]
            return obj

        def handle_id_and_kwargs(
            obj: Dict[str, Any], root: bool = False
        ) -> Dict[str, Any]:
            """Recursively rewrite the ``id``/``name`` and ``kwargs`` fields.

            For every non-root dictionary holding ``id`` or ``name``, the last
            element of ``id`` (or the ``name`` when ``id`` is missing/empty) is
            stored as ``_kind`` and both original keys are removed; the
            entries of ``kwargs`` are moved up one level.

            :param obj: The dictionary (or list, or leaf value) to rewrite.
            :param root: Whether ``obj`` is the root dictionary, whose own
                ``id``/``name`` must be left untouched.
            :return: The rewritten object.
            """
            if isinstance(obj, dict):
                if ("id" in obj or "name" in obj) and not root:
                    _kind = obj.get("id")
                    if not _kind:
                        _kind = [obj.get("name")]
                    obj["_kind"] = _kind[-1]
                    obj.pop("id", None)
                    obj.pop("name", None)
                    # NOTE(review): the kwargs handling is nested under the
                    # id/name branch; this nesting was reconstructed from a
                    # source that had lost its indentation - confirm.
                    if "kwargs" in obj:
                        kwargs = obj.pop("kwargs")
                        for k, v in kwargs.items():
                            obj[k] = v
                for k, v in obj.items():
                    obj[k] = handle_id_and_kwargs(v)
            elif isinstance(obj, list):
                obj = [handle_id_and_kwargs(x) for x in obj]
            return obj

        def transform_serialized(serialized: Dict[str, Any]) -> Dict[str, Any]:
            """Apply both rewrites to a dictionary, treating it as the root.

            :param serialized: The dictionary to transform.
            :return: The transformed dictionary.
            """
            serialized = handle_id_and_kwargs(serialized, root=True)
            serialized = remove_exact_and_partial_keys(serialized)
            return serialized

        def transform_run(run: Dict[str, Any]) -> Dict[str, Any]:
            """Transform one run dictionary into its model-dict entry.

            :param run: The run dictionary to transform.
            :return: ``{"<execution_order>_<name>": transformed_run}``.
            """
            transformed_dict = transform_serialized(run)

            # Promote the fields of "serialized" to the top level of the run.
            serialized = transformed_dict.pop("serialized")
            for k, v in serialized.items():
                transformed_dict[k] = v

            _kind = transformed_dict.get("_kind", None)
            name = transformed_dict.pop("name", None)
            exec_ord = transformed_dict.pop("execution_order", None)

            # Fall back to the kind when the run has no name.
            if not name:
                name = _kind

            output_dict = {
                f"{exec_ord}_{name}": transformed_dict,
            }
            return output_dict

        return list(map(transform_run, runs))

    def build_tree(self, runs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build a nested dictionary from a flat list of runs.

        Each run's data is inserted into its parent's data under the run's
        own key; ``id`` and ``parent_run_id`` are removed from the data.

        :param runs: Single-entry dictionaries as produced by
            ``modify_serialized_iterative``.
        :return: The entry of the first run that has no parent, with all
            descendants nested inside it.
        """
        id_to_data = {}
        child_to_parent = {}

        for entity in runs:
            for key, data in entity.items():
                id_val = data.pop("id", None)
                parent_run_id = data.pop("parent_run_id", None)
                id_to_data[id_val] = {key: data}
                if parent_run_id:
                    child_to_parent[id_val] = parent_run_id

        for child_id, parent_id in child_to_parent.items():
            parent_dict = id_to_data[parent_id]
            child_dict = id_to_data[child_id]
            # Every entry of id_to_data holds exactly one key.
            parent_key = next(iter(parent_dict))
            child_key = next(iter(child_dict))
            parent_dict[parent_key][child_key] = child_dict[child_key]

        root_dict = next(
            data for id_val, data in id_to_data.items() if id_val not in child_to_parent
        )

        return root_dict
class WandbRunArgs(TypedDict, total=False):
    """Arguments for the WandbTracer.

    The dictionary is unpacked into ``wandb.init(**run_args)`` when the
    tracer has to start a run. It is declared with ``total=False`` because
    callers provide only the arguments they want to override; no key is
    required.
    """

    job_type: Optional[str]
    dir: Optional["StrPath"]
    config: Union[Dict, str, None]
    project: Optional[str]
    entity: Optional[str]
    reinit: Optional[bool]
    tags: Optional[Sequence]
    group: Optional[str]
    name: Optional[str]
    notes: Optional[str]
    magic: Optional[Union[dict, str, bool]]
    config_exclude_keys: Optional[List[str]]
    config_include_keys: Optional[List[str]]
    anonymous: Optional[str]
    mode: Optional[str]
    allow_val_change: Optional[bool]
    resume: Optional[Union[bool, str]]
    force: Optional[bool]
    tensorboard: Optional[bool]
    sync_tensorboard: Optional[bool]
    monitor_gym: Optional[bool]
    save_code: Optional[bool]
    id: Optional[str]
    # When this key is absent the tracer supplies ``{"silent": True}``.
    settings: Union["WBSettings", Dict[str, Any], None]
class WandbTracer(BaseTracer):
    """Callback handler that logs to Weights and Biases.

    For every run handed to ``_persist_run`` the handler logs a
    ``langchain_trace`` entry to the active W&B run, combining the run's
    trace spans with its model dict.
    """

    _run: Optional[WBRun] = None
    _run_args: Optional[WandbRunArgs] = None

    def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
        """Initialize the WandbTracer.

        Parameters:
            run_args: (dict, optional) Arguments to pass to `wandb.init()`. If
                not provided, `wandb.init()` is called with only the default
                `settings={"silent": True}`. Please refer to `wandb.init` for
                more details.
            **kwargs: Forwarded to the base tracer's constructor.

        Raises:
            ImportError: If the `wandb` package cannot be imported.

        To monitor all LangChain activity with W&B, add this tracer like any
        other LangChain callback:
        ```
        from wandb.integration.langchain import WandbTracer

        tracer = WandbTracer()
        chain = LLMChain(llm, callbacks=[tracer])
        # ...end of notebook / script:
        tracer.finish()
        ```
        """
        super().__init__(**kwargs)
        try:
            import wandb
            from wandb.sdk.data_types import trace_tree
        except ImportError as e:
            raise ImportError(
                "Could not import wandb python package. "
                "Please install it with `pip install -U wandb`."
            ) from e
        self._wandb = wandb
        self._trace_tree = trace_tree
        self._run_args = run_args
        # Only announce the run URL when this tracer is the one starting it.
        self._ensure_run(should_print_url=(wandb.run is None))
        self.run_processor = RunProcessor(self._wandb, self._trace_tree)

    def finish(self) -> None:
        """Finish the W&B run.

        Proxy for `wandb.finish()`.
        """
        self._wandb.finish()

    def _log_trace_from_run(self, run: Run) -> None:
        """Log a LangChain run to W&B as a W&B Trace."""
        self._ensure_run()

        root_span = self.run_processor.process_span(run)
        if root_span is None:
            # Span conversion failed; nothing will be logged, so the model
            # dict is not built either.
            return
        model_dict = self.run_processor.process_model(run)

        model_trace = self._trace_tree.WBTraceTree(
            root_span=root_span,
            model_dict=model_dict,
        )
        if self._wandb.run is not None:
            self._wandb.run.log({"langchain_trace": model_trace})

    def _ensure_run(self, should_print_url: bool = False) -> None:
        """Ensure that an active W&B run exists.

        If there is none, a new run is started with the provided run_args;
        `settings={"silent": True}` is added unless the caller supplied
        settings.

        :param should_print_url: Whether to print the URL of a newly started
            run.
        """
        if self._wandb.run is None:
            # Copy so that the caller's dictionary is never mutated.
            run_args: Dict = {**(self._run_args or {})}

            if "settings" not in run_args:
                run_args["settings"] = {"silent": True}

            self._wandb.init(**run_args)
            # NOTE(review): this block is nested under the "no active run"
            # branch; the nesting was reconstructed from a source that had
            # lost its indentation - confirm.
            if self._wandb.run is not None:
                if should_print_url:
                    run_url = self._wandb.run.settings.run_url
                    self._wandb.termlog(
                        f"Streaming LangChain activity to W&B at {run_url}\n"
                        "`WandbTracer` is currently in beta.\n"
                        "Please report any issues to "
                        "https://github.com/wandb/wandb/issues with the tag "
                        "`langchain`."
                    )

                self._wandb.run._label(repo="langchain")

    def _persist_run(self, run: "Run") -> None:
        """Persist a run by logging it to W&B as a trace."""
        self._log_trace_from_run(run)