import asyncio
import base64
import re
from dataclasses import asdict
from typing import Literal, Optional

from langchain_core.runnables.graph import (
    CurveStyle,
    Edge,
    MermaidDrawMethod,
    Node,
    NodeStyles,
)


def draw_mermaid(
    nodes: dict[str, Node],
    edges: list[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: Optional[NodeStyles] = None,
    wrap_label_n_words: int = 9,
) -> str:
    """Draw a Mermaid graph from the provided nodes and edges.

    Node ids may contain ``:`` separators; edges whose endpoints share a
    leading ``:``-separated prefix are rendered inside a Mermaid ``subgraph``
    named after the last component of that prefix.

    Args:
        nodes: Mapping of node id to node. Node declarations are only emitted
            when ``with_styles`` is True.
        edges: Edges to draw; each has a source, a target, optional data
            (used as the edge label) and a conditional flag.
        first_node: Id of the first node, rendered with the ``first`` class.
            Defaults to None.
        last_node: Id of the last node, rendered with the ``last`` class.
            Defaults to None.
        with_styles: Whether to include the init directive, node
            declarations and class definitions. Defaults to True.
        curve_style: Curve style for the edges.
            Defaults to CurveStyle.LINEAR.
        node_styles: Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words: Insert a line break in edge labels every this
            many words. Values <= 0 disable wrapping. Defaults to 9.

    Returns:
        Mermaid graph syntax.

    Raises:
        ValueError: If two subgraphs would be emitted with the same name.
    """
    # Initialize Mermaid graph configuration
    if with_styles:
        mermaid_graph = (
            f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
            f"}}}}}}%%\ngraph TD;\n"
        )
    else:
        mermaid_graph = "graph TD;\n"

    if with_styles:
        # Node formatting templates, keyed by node id with a default fallback
        default_class_label = "default"
        format_dict = {default_class_label: "{0}({1})"}
        if first_node is not None:
            format_dict[first_node] = "{0}([{1}]):::first"
        if last_node is not None:
            format_dict[last_node] = "{0}([{1}]):::last"

        # MARKDOWN_SPECIAL_CHARS is expected to be a module-level constant
        # defined outside this function.
        markdown_chars = tuple(MARKDOWN_SPECIAL_CHARS)

        # Add nodes to the graph
        for key, node in nodes.items():
            # Only the last ":"-separated component is shown as the label
            node_name = node.name.split(":")[-1]
            # Wrap names delimited by markdown characters in <p> so Mermaid
            # shows them verbatim instead of interpreting them as markdown
            if node_name.startswith(markdown_chars) and node_name.endswith(
                markdown_chars
            ):
                label = f"<p>{node_name}</p>"
            else:
                label = node_name
            if node.metadata:
                metadata_lines = "\n".join(
                    f"{meta_key} = {meta_value}"
                    for meta_key, meta_value in node.metadata.items()
                )
                label = f"{label}<hr/><small><em>{metadata_lines}</em></small>"
            template = format_dict.get(key, format_dict[default_class_label])
            node_label = template.format(_escape_node_label(key), label)
            mermaid_graph += f"\t{node_label}\n"

    # Group edges by the common leading prefix of their endpoints
    edge_groups: dict[str, list[Edge]] = {}
    for edge in edges:
        src_parts = edge.source.split(":")
        tgt_parts = edge.target.split(":")
        # Stop at the first mismatch: later components that happen to be
        # equal (e.g. "a:x" -> "b:x") are not part of a shared prefix.
        common_parts: list[str] = []
        for src_part, tgt_part in zip(src_parts, tgt_parts):
            if src_part != tgt_part:
                break
            common_parts.append(src_part)
        edge_groups.setdefault(":".join(common_parts), []).append(edge)

    seen_subgraphs: set[str] = set()

    def add_subgraph(group_edges: list[Edge], prefix: str) -> None:
        """Append the edges of one prefix group, then any nested groups."""
        nonlocal mermaid_graph
        # A lone self-loop shares its full id as "prefix"; it must not open
        # a subgraph of its own.
        self_loop = (
            len(group_edges) == 1
            and group_edges[0].source == group_edges[0].target
        )
        opens_subgraph = bool(prefix) and not self_loop
        if opens_subgraph:
            subgraph = prefix.split(":")[-1]
            if subgraph in seen_subgraphs:
                msg = (
                    f"Found duplicate subgraph '{subgraph}' -- this likely means that "
                    "you're reusing a subgraph node with the same name. "
                    "Please adjust your graph to have subgraph nodes with unique names."
                )
                raise ValueError(msg)
            seen_subgraphs.add(subgraph)
            mermaid_graph += f"\tsubgraph {subgraph}\n"

        for edge in group_edges:
            source, target = edge.source, edge.target

            if edge.data is not None:
                edge_data = edge.data
                words = str(edge_data).split()
                # Add BR every wrap_label_n_words words; a non-positive
                # setting disables wrapping instead of breaking range().
                if 0 < wrap_label_n_words < len(words):
                    edge_data = "&nbsp<br>&nbsp".join(
                        " ".join(words[i : i + wrap_label_n_words])
                        for i in range(0, len(words), wrap_label_n_words)
                    )
                if edge.conditional:
                    edge_label = f" -. &nbsp;{edge_data}&nbsp; .-> "
                else:
                    edge_label = f" -- &nbsp;{edge_data}&nbsp; --> "
            else:
                edge_label = " -.-> " if edge.conditional else " --> "

            mermaid_graph += (
                f"\t{_escape_node_label(source)}{edge_label}"
                f"{_escape_node_label(target)};\n"
            )

        # Recursively add nested subgraphs
        for nested_prefix in edge_groups:
            if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
                continue
            add_subgraph(edge_groups[nested_prefix], nested_prefix)

        if opens_subgraph:
            mermaid_graph += "\tend\n"

    # Start with the top-level edges (no common prefix)
    add_subgraph(edge_groups.get("", []), "")

    # Add remaining subgraphs; nested ones are reached through recursion
    for prefix in edge_groups:
        if ":" in prefix or prefix == "":
            continue
        add_subgraph(edge_groups[prefix], prefix)

    # Add custom styles for nodes
    if with_styles:
        mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
    return mermaid_graph
def _escape_node_label(node_label: str) -> str:
    """Escape a node label so it is a valid Mermaid node id.

    Args:
        node_label: Raw node label.

    Returns:
        The label with every character other than ASCII letters, digits,
        ``-`` and ``_`` replaced by ``_``.
    """
    # The "-" after "A-Z" is a literal hyphen: the preceding range is already
    # complete, so it cannot start a new one.
    return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)


def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
    """Generate Mermaid ``classDef`` lines for the different node types.

    Args:
        node_colors: Dataclass instance whose field names are Mermaid class
            names and whose values are the corresponding style strings.

    Returns:
        One tab-indented ``classDef`` line per field, each newline-terminated;
        an empty string when there are no fields.
    """
    return "".join(
        f"\tclassDef {class_name} {style}\n"
        for class_name, style in asdict(node_colors).items()
    )
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
) -> bytes:
    """Draw a Mermaid graph as PNG using the provided syntax.

    Args:
        mermaid_syntax: Mermaid graph syntax.
        output_file_path: Path to save the PNG image. Defaults to None.
        draw_method: Method to draw the graph.
            Defaults to MermaidDrawMethod.API.
        background_color: Background color of the image.
            Defaults to "white".
        padding: Padding around the image. Only used by the Pyppeteer
            method. Defaults to 10.

    Returns:
        PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
    """
    if draw_method == MermaidDrawMethod.PYPPETEER:
        import asyncio

        # The Pyppeteer renderer is a coroutine; drive it to completion here
        # so this function stays synchronous. asyncio.run() cannot be used
        # while another event loop is already running in this thread.
        img_bytes = asyncio.run(
            _render_mermaid_using_pyppeteer(
                mermaid_syntax, output_file_path, background_color, padding
            )
        )
    elif draw_method == MermaidDrawMethod.API:
        img_bytes = _render_mermaid_using_api(
            mermaid_syntax, output_file_path, background_color
        )
    else:
        supported_methods = ", ".join([m.value for m in MermaidDrawMethod])
        msg = (
            f"Invalid draw method: {draw_method}. "
            f"Supported draw methods are: {supported_methods}"
        )
        raise ValueError(msg)

    return img_bytes
async def _render_mermaid_using_pyppeteer(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    padding: int = 10,
    device_scale_factor: int = 3,
) -> bytes:
    """Render a Mermaid graph to an image with a headless browser (Pyppeteer).

    Args:
        mermaid_syntax: Mermaid graph syntax.
        output_file_path: Path to save the image to. Defaults to None.
        background_color: CSS background applied to the page body.
            Defaults to "white".
        padding: Pixels added to the SVG width and height when sizing the
            viewport. Defaults to 10.
        device_scale_factor: Device scale factor of the viewport.
            Defaults to 3.

    Returns:
        Screenshot image bytes.

    Raises:
        ImportError: If Pyppeteer is not installed.
    """
    try:
        from pyppeteer import launch  # type: ignore[import]
    except ImportError as e:
        msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
        raise ImportError(msg) from e

    browser = await launch()
    try:
        page = await browser.newPage()

        # Setup Mermaid JS
        await page.goto("about:blank")
        await page.addScriptTag(
            {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
        )
        await page.evaluate(
            """() => {
                mermaid.initialize({startOnLoad:true});
            }"""
        )

        # Render SVG
        svg_code = await page.evaluate(
            """(mermaidGraph) => {
                return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
            }""",
            mermaid_syntax,
        )

        # Put the SVG on the page and set the page background
        await page.evaluate(
            """(svg, background_color) => {
                document.body.innerHTML = svg;
                document.body.style.background = background_color;
            }""",
            svg_code["svg"],
            background_color,
        )

        # Size the viewport to the rendered SVG, then take a screenshot
        dimensions = await page.evaluate(
            """() => {
                const svgElement = document.querySelector('svg');
                const rect = svgElement.getBoundingClientRect();
                return { width: rect.width, height: rect.height };
            }"""
        )
        await page.setViewport(
            {
                "width": int(dimensions["width"] + padding),
                "height": int(dimensions["height"] + padding),
                "deviceScaleFactor": device_scale_factor,
            }
        )
        img_bytes = await page.screenshot({"fullPage": False})
    finally:
        # Always shut the browser down, even when rendering fails
        await browser.close()

    def write_to_file(path: str, data: bytes) -> None:
        with open(path, "wb") as file:
            file.write(data)

    if output_file_path is not None:
        # Blocking file I/O goes to the default executor so the event loop
        # is not stalled while writing.
        await asyncio.get_running_loop().run_in_executor(
            None, write_to_file, output_file_path, img_bytes
        )

    return img_bytes


def _render_mermaid_using_api(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
) -> bytes:
    """Render a Mermaid graph using the Mermaid.INK API.

    Args:
        mermaid_syntax: Mermaid graph syntax.
        output_file_path: Path to save the image to. Defaults to None.
        background_color: Hex color code (``#rgb`` / ``#rrggbb``) or a named
            color. When None, no background is requested. Defaults to "white".
        file_type: Image format to request. When None, the service default
            is used. Defaults to "png".

    Returns:
        Image bytes returned by the service.

    Raises:
        ImportError: If ``requests`` is not installed.
        ValueError: If the service responds with a non-200 status code.
    """
    try:
        import requests  # type: ignore[import]
    except ImportError as e:
        msg = (
            "Install the `requests` module to use the Mermaid.INK API: "
            "`pip install requests`."
        )
        raise ImportError(msg) from e

    # Use Mermaid API to render the image
    mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
        "ascii"
    )

    # Only send the parameters that were actually provided, so a literal
    # "None" never ends up in the query string.
    query_parts = []
    if file_type is not None:
        query_parts.append(f"type={file_type}")
    if background_color is not None:
        hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")
        if hex_color_pattern.match(background_color):
            # A raw "#" would start the URL fragment and never reach the
            # server; mermaid.ink documents hex colors without the "#".
            bg_color = background_color[1:]
        else:
            # Named colors are marked with a leading "!"
            bg_color = f"!{background_color}"
        query_parts.append(f"bgColor={bg_color}")

    image_url = f"https://mermaid.ink/img/{mermaid_syntax_encoded}"
    if query_parts:
        image_url += "?" + "&".join(query_parts)

    response = requests.get(image_url, timeout=10)
    if response.status_code != 200:
        msg = (
            f"Failed to render the graph using the Mermaid.INK API. "
            f"Status code: {response.status_code}."
        )
        raise ValueError(msg)

    img_bytes = response.content
    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes