# Source code for langchain_core.runnables.graph

from __future__ import annotations

import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Type,
    TypedDict,
    Union,
    overload,
)
from uuid import UUID, uuid4

from langchain_core.pydantic_v1 import BaseModel

if TYPE_CHECKING:
    from langchain_core.runnables.base import Runnable as RunnableType


class LabelsDict(TypedDict):
    """Dictionary of labels for nodes and edges in a graph.

    The field types use ``typing.Dict`` (like the rest of this module) rather
    than builtin ``dict[...]`` so that ``typing.get_type_hints`` can resolve
    them on Python 3.8 as well.
    """

    # Labels for the nodes, keyed by node id.
    nodes: Dict[str, str]
    # Labels for the edges.
    # NOTE(review): the key format is defined by the drawer that consumes
    # these labels — confirm against the PNG drawer.
    edges: Dict[str, str]
def is_uuid(value: str) -> bool:
    """Report whether ``value`` parses as a UUID.

    Args:
        value: The string to check.

    Returns:
        True if ``uuid.UUID`` accepts ``value``; False if it raises
        ``ValueError``.
    """
    try:
        UUID(value)
    except ValueError:
        return False
    else:
        return True
class Edge(NamedTuple):
    """Directed edge in a graph, referring to its endpoints by node id."""

    # Id of the node the edge starts from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Optional data (label) attached to the edge.
    data: Optional[str] = None
    # Whether the edge is conditional.
    conditional: bool = False
class Node(NamedTuple):
    """Node in a graph."""

    # Identifier of the node; keys the node within its graph.
    id: str
    # Payload: a pydantic model class or a runnable.
    data: Union[Type[BaseModel], RunnableType]
class Branch(NamedTuple):
    """Branch in a graph.

    ``ends`` is annotated with ``typing.Dict`` (like the rest of this module)
    rather than builtin ``dict[...]`` so that ``typing.get_type_hints`` can
    resolve it on Python 3.8 as well.
    """

    # Callable (taking any arguments) that returns a string.
    condition: Callable[..., str]
    # Optional str -> str mapping.
    # NOTE(review): presumably maps condition results to target node ids;
    # this module never reads it, so confirm against callers.
    ends: Optional[Dict[str, str]]
class CurveStyle(Enum):
    """Curve styles supported by Mermaid.

    Each member's value is the identifier string Mermaid uses for the style.
    """

    BASIS = "basis"
    # Bump variants.
    BUMP_X = "bumpX"
    BUMP_Y = "bumpY"
    CARDINAL = "cardinal"
    CATMULL_ROM = "catmullRom"
    LINEAR = "linear"
    # Monotone variants.
    MONOTONE_X = "monotoneX"
    MONOTONE_Y = "monotoneY"
    NATURAL = "natural"
    # Step variants.
    STEP = "step"
    STEP_AFTER = "stepAfter"
    STEP_BEFORE = "stepBefore"
@dataclass
class NodeColors:
    """Hex color codes for the different kinds of node."""

    # Color of the start node.
    start: str = "#ffdfba"
    # Color of the end node.
    end: str = "#baffc9"
    # Color of every other node.
    other: str = "#fad7de"
class MermaidDrawMethod(Enum):
    """Ways of rendering a Mermaid graph to an image."""

    # Render locally using Pyppeteer.
    PYPPETEER = "pyppeteer"
    # Render through the Mermaid.INK API.
    API = "api"
def node_data_str(node: Node) -> str:
    """Convert a node's data to a display string.

    A node whose id is not a UUID is labelled with its id as-is. Otherwise
    the label is derived from the data:

    * a Runnable uses ``str(data)`` when that is a single line that starts
      neither with ``"<"`` nor with a lowercase character, truncated to 42
      characters plus ``"..."``. It falls back to the class name otherwise,
      or if building the string fails in any way.
    * anything else uses its ``__name__``.

    A leading ``"Runnable"`` prefix is then stripped from the derived label.

    Args:
        node: The node to convert.

    Returns:
        The string representation of the node's data.
    """
    # Imported locally, presumably to avoid a circular import with
    # langchain_core.runnables.base.
    from langchain_core.runnables.base import Runnable

    if not is_uuid(node.id):
        return node.id

    if isinstance(node.data, Runnable):
        try:
            label = str(node.data)
            # Reprs like "<object at 0x...>", lowercase-leading or multi-line
            # strings make poor labels; use the class name for those.
            # (An empty string raises IndexError here, which is caught below.)
            unreadable = (
                label.startswith("<")
                or label[0] != label[0].upper()
                or len(label.splitlines()) > 1
            )
            if unreadable:
                label = node.data.__class__.__name__
            elif len(label) > 42:
                label = label[:42] + "..."
        except Exception:
            label = node.data.__class__.__name__
    else:
        label = node.data.__name__

    # len("Runnable") == 8
    return label[8:] if label.startswith("Runnable") else label
def node_data_json(
    node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]:
    """Convert a node's data to a JSON-serializable format.

    Args:
        node: The node to convert.
        with_schemas: When the data is a pydantic model class, whether to
            emit its schema (``.schema()``) instead of its display string.

    Returns:
        A dict with a ``"type"`` key (``"runnable"``, ``"schema"`` or
        ``"unknown"``) and a ``"data"`` key holding the serialized data.
    """
    from langchain_core.load.serializable import to_json_not_implemented
    from langchain_core.runnables.base import Runnable, RunnableSerializable

    payload = node.data

    if isinstance(payload, RunnableSerializable):
        return {
            "type": "runnable",
            "data": {"id": payload.lc_id(), "name": payload.get_name()},
        }

    if isinstance(payload, Runnable):
        # Non-serializable runnables take their id from the
        # "not implemented" serialization stub.
        return {
            "type": "runnable",
            "data": {
                "id": to_json_not_implemented(payload)["id"],
                "name": payload.get_name(),
            },
        }

    if inspect.isclass(payload) and issubclass(payload, BaseModel):
        if with_schemas:
            return {"type": "schema", "data": payload.schema()}
        return {"type": "schema", "data": node_data_str(node)}

    return {"type": "unknown", "data": node_data_str(node)}
@dataclass
class Graph:
    """Graph of nodes and edges.

    Attributes:
        nodes: Mapping from node id to node.
        edges: Edges of the graph, in insertion order.
    """

    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: List[Edge] = field(default_factory=list)

    def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
        """Convert the graph to a JSON-serializable format.

        Nodes whose id is a UUID are emitted with their position in
        ``self.nodes`` as id, so the output does not depend on randomly
        generated ids. Other ids are kept as-is.

        Args:
            with_schemas: Passed to ``node_data_json``. Controls whether
                pydantic model nodes include their schema.

        Returns:
            A dict with a ``"nodes"`` list and an ``"edges"`` list.
        """
        stable_node_ids = {
            node.id: i if is_uuid(node.id) else node.id
            for i, node in enumerate(self.nodes.values())
        }
        edges: List[Dict[str, Any]] = []
        for edge in self.edges:
            edge_dict: Dict[str, Any] = {
                "source": stable_node_ids[edge.source],
                "target": stable_node_ids[edge.target],
            }
            if edge.data is not None:
                edge_dict["data"] = edge.data
            if edge.conditional:
                edge_dict["conditional"] = True
            edges.append(edge_dict)

        return {
            "nodes": [
                {
                    "id": stable_node_ids[node.id],
                    **node_data_json(node, with_schemas=with_schemas),
                }
                for node in self.nodes.values()
            ],
            "edges": edges,
        }

    def __bool__(self) -> bool:
        """Return True if the graph has at least one node."""
        return bool(self.nodes)

    def next_id(self) -> str:
        """Return a new random node id (the hex form of a UUID4)."""
        return uuid4().hex

    def add_node(
        self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
    ) -> Node:
        """Add a node to the graph and return it.

        Args:
            data: The data of the node.
            id: Optional id for the node. When omitted (or empty), a new id
                is generated with ``next_id``.

        Returns:
            The node that was added.

        Raises:
            ValueError: If a node with the given ``id`` already exists.
        """
        if id is not None and id in self.nodes:
            raise ValueError(f"Node with id {id} already exists")
        node = Node(id=id or self.next_id(), data=data)
        self.nodes[node.id] = node
        return node

    def remove_node(self, node: Node) -> None:
        """Remove a node from the graph, together with every edge touching it.

        Args:
            node: The node to remove.

        Raises:
            KeyError: If the node's id is not in the graph.
        """
        self.nodes.pop(node.id)
        self.edges = [
            edge
            for edge in self.edges
            if edge.source != node.id and edge.target != node.id
        ]

    def add_edge(
        self,
        source: Node,
        target: Node,
        data: Optional[str] = None,
        conditional: bool = False,
    ) -> Edge:
        """Add an edge to the graph and return it.

        Args:
            source: Node the edge starts from; must already be in the graph.
            target: Node the edge points to; must already be in the graph.
            data: Optional data (label) for the edge.
            conditional: Whether the edge is conditional.

        Returns:
            The edge that was added.

        Raises:
            ValueError: If ``source`` or ``target`` is not in the graph.
        """
        if source.id not in self.nodes:
            raise ValueError(f"Source node {source.id} not in graph")
        if target.id not in self.nodes:
            raise ValueError(f"Target node {target.id} not in graph")
        edge = Edge(
            source=source.id, target=target.id, data=data, conditional=conditional
        )
        self.edges.append(edge)
        return edge

    def extend(
        self, graph: Graph, *, prefix: str = ""
    ) -> Tuple[Optional[Node], Optional[Node]]:
        """Add all nodes and edges from another graph.

        Note that this does not check for duplicates, nor does it connect
        the two graphs.

        Args:
            graph: The graph whose nodes and edges are copied in.
            prefix: If non-empty, each copied node id becomes
                ``"{prefix}:{id}"``. Ignored when every node id in ``graph``
                is a UUID.

        Returns:
            The (prefixed) first and last nodes of ``graph``. Each is ``None``
            when ``graph.first_node()`` / ``graph.last_node()`` is ``None``.
        """
        # Skip prefixing when all ids are UUIDs (presumably because they are
        # already collision-free).
        if all(is_uuid(node.id) for node in graph.nodes.values()):
            prefix = ""

        def prefixed(id: str) -> str:
            return f"{prefix}:{id}" if prefix else id

        # prefix each node
        self.nodes.update(
            {prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
        )
        # prefix each edge's source and target
        self.edges.extend(
            [
                Edge(
                    prefixed(edge.source),
                    prefixed(edge.target),
                    edge.data,
                    edge.conditional,
                )
                for edge in graph.edges
            ]
        )
        # return (prefixed) first and last nodes of the subgraph
        first, last = graph.first_node(), graph.last_node()
        return (
            Node(prefixed(first.id), first.data) if first else None,
            Node(prefixed(last.id), last.data) if last else None,
        )

    def first_node(self) -> Optional[Node]:
        """Find the single node that is not the target of any edge.

        Returns None if there is no such node, or more than one. When
        drawing the graph, this node is used as the start.
        """
        targets = {edge.target for edge in self.edges}
        found: List[Node] = []
        for node in self.nodes.values():
            if node.id not in targets:
                found.append(node)
        return found[0] if len(found) == 1 else None

    def last_node(self) -> Optional[Node]:
        """Find the single node that is not the source of any edge.

        Returns None if there is no such node, or more than one. When
        drawing the graph, this node is used as the end.
        """
        sources = {edge.source for edge in self.edges}
        found: List[Node] = []
        for node in self.nodes.values():
            if node.id not in sources:
                found.append(node)
        return found[0] if len(found) == 1 else None

    def trim_first_node(self) -> None:
        """Remove the first node if it exists and has exactly one outgoing edge.

        The intent is that removing it does not leave the graph without a
        "first" node. A graph consisting only of that node is emptied.
        """
        first_node = self.first_node()
        if first_node is not None:
            if (
                len(self.nodes) == 1
                or len([edge for edge in self.edges if edge.source == first_node.id])
                == 1
            ):
                self.remove_node(first_node)

    def trim_last_node(self) -> None:
        """Remove the last node if it exists and has exactly one incoming edge.

        The intent is that removing it does not leave the graph without a
        "last" node. A graph consisting only of that node is emptied.
        """
        last_node = self.last_node()
        if last_node is not None:
            if (
                len(self.nodes) == 1
                or len([edge for edge in self.edges if edge.target == last_node.id])
                == 1
            ):
                self.remove_node(last_node)

    def draw_ascii(self) -> str:
        """Return an ASCII drawing of the graph, labelling nodes via node_data_str."""
        from langchain_core.runnables.graph_ascii import draw_ascii

        return draw_ascii(
            {node.id: node_data_str(node) for node in self.nodes.values()},
            self.edges,
        )

    def print_ascii(self) -> None:
        """Print the ASCII drawing of the graph to stdout."""
        print(self.draw_ascii())  # noqa: T201

    @overload
    def draw_png(
        self,
        output_file_path: str,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> None: ...

    # ``= None`` lets type checkers accept ``draw_png()`` with no path, which
    # the implementation supports.
    @overload
    def draw_png(
        self,
        output_file_path: None = None,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> bytes: ...

    def draw_png(
        self,
        output_file_path: Optional[str] = None,
        fontname: Optional[str] = None,
        labels: Optional[LabelsDict] = None,
    ) -> Union[bytes, None]:
        """Draw the graph as a PNG image.

        Args:
            output_file_path: Path to write the PNG to. When None, the PNG
                bytes are returned instead (see the overloads).
            fontname: Font name passed to the PNG drawer.
            labels: Optional node and edge labels. Node labels given here
                override the defaults computed with ``node_data_str``.

        Returns:
            The PNG bytes when ``output_file_path`` is None, otherwise None.
        """
        from langchain_core.runnables.graph_png import PngDrawer

        default_node_labels = {
            node.id: node_data_str(node) for node in self.nodes.values()
        }

        return PngDrawer(
            fontname,
            LabelsDict(
                nodes={
                    **default_node_labels,
                    **(labels["nodes"] if labels is not None else {}),
                },
                edges=labels["edges"] if labels is not None else {},
            ),
        ).draw(self, output_file_path)

    def draw_mermaid(
        self,
        *,
        with_styles: bool = True,
        curve_style: CurveStyle = CurveStyle.LINEAR,
        node_colors: Optional[NodeColors] = None,
        wrap_label_n_words: int = 9,
    ) -> str:
        """Draw the graph as Mermaid syntax.

        Args:
            with_styles: Forwarded to the Mermaid renderer.
            curve_style: Curve style forwarded to the Mermaid renderer.
            node_colors: Node colors forwarded to the Mermaid renderer. When
                None, ``NodeColors(start="#ffdfba", end="#baffc9",
                other="#fad7de")`` is used (the previous default).
            wrap_label_n_words: Forwarded to the Mermaid renderer. Presumably
                the number of words after which a label wraps.

        Returns:
            The Mermaid syntax string.
        """
        from langchain_core.runnables.graph_mermaid import draw_mermaid

        # Build the default on each call. A NodeColors instance used as the
        # parameter default is created once and shared (and mutable) across
        # every call.
        if node_colors is None:
            node_colors = NodeColors(start="#ffdfba", end="#baffc9", other="#fad7de")

        nodes = {node.id: node_data_str(node) for node in self.nodes.values()}

        first_node = self.first_node()
        first_label = node_data_str(first_node) if first_node is not None else None

        last_node = self.last_node()
        last_label = node_data_str(last_node) if last_node is not None else None

        return draw_mermaid(
            nodes=nodes,
            edges=self.edges,
            first_node_label=first_label,
            last_node_label=last_label,
            with_styles=with_styles,
            curve_style=curve_style,
            node_colors=node_colors,
            wrap_label_n_words=wrap_label_n_words,
        )

    def draw_mermaid_png(
        self,
        *,
        curve_style: CurveStyle = CurveStyle.LINEAR,
        node_colors: Optional[NodeColors] = None,
        wrap_label_n_words: int = 9,
        output_file_path: Optional[str] = None,
        draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
        background_color: str = "white",
        padding: int = 10,
    ) -> bytes:
        """Draw the graph as a PNG rendered from its Mermaid syntax.

        Args:
            curve_style: Passed to ``draw_mermaid``.
            node_colors: Passed to ``draw_mermaid``. None selects its default
                colors.
            wrap_label_n_words: Passed to ``draw_mermaid``.
            output_file_path: Forwarded to the Mermaid PNG renderer.
            draw_method: How to render the PNG (API or Pyppeteer).
            background_color: Forwarded to the Mermaid PNG renderer.
            padding: Forwarded to the Mermaid PNG renderer.

        Returns:
            The PNG image bytes returned by the Mermaid PNG renderer.
        """
        from langchain_core.runnables.graph_mermaid import draw_mermaid_png

        mermaid_syntax = self.draw_mermaid(
            curve_style=curve_style,
            node_colors=node_colors,
            wrap_label_n_words=wrap_label_n_words,
        )
        return draw_mermaid_png(
            mermaid_syntax=mermaid_syntax,
            output_file_path=output_file_path,
            draw_method=draw_method,
            background_color=background_color,
            padding=padding,
        )