Source code for langchain_core.runnables.utils

"""运行代码的实用工具。"""

from __future__ import annotations

import ast
import asyncio
import inspect
import textwrap
from functools import lru_cache
from inspect import signature
from itertools import groupby
from typing import (
    Any,
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Set,
    Type,
    TypeVar,
    Union,
)

from typing_extensions import TypeGuard

from langchain_core.pydantic_v1 import BaseConfig, BaseModel
from langchain_core.pydantic_v1 import create_model as _create_model_base
from langchain_core.runnables.schema import StreamEvent

# Type variable for a runnable's input type (declared contravariant).
Input = TypeVar("Input", contravariant=True)
# Output type should implement __concat__, as eg str, list, dict do
# (declared covariant).
Output = TypeVar("Output", covariant=True)


async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await ``coro`` while holding ``semaphore``.

    Args:
        semaphore: The semaphore to acquire for the duration of ``coro``.
        coro: The coroutine to await.

    Returns:
        Whatever ``coro`` returns.
    """
    async with semaphore:
        outcome = await coro
    return outcome
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Gather coroutines while limiting how many run at the same time.

    Args:
        n: Maximum number of coroutines allowed to run concurrently, or
            ``None`` for no limit.
        *coros: The coroutines to run.

    Returns:
        The results of the coroutines, as returned by ``asyncio.gather``.
    """
    if n is not None:
        limiter = asyncio.Semaphore(n)
        return await asyncio.gather(*(gated_coro(limiter, coro) for coro in coros))
    return await asyncio.gather(*coros)
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
    """Return whether ``callable`` declares a ``run_manager`` parameter.

    Returns ``False`` when ``inspect.signature`` cannot determine a signature
    (it raises ``ValueError``).
    """
    try:
        parameters = signature(callable).parameters
    except ValueError:
        return False
    return "run_manager" in parameters
def accepts_config(callable: Callable[..., Any]) -> bool:
    """Return whether ``callable`` declares a ``config`` parameter.

    Returns ``False`` when ``inspect.signature`` cannot determine a signature
    (it raises ``ValueError``).
    """
    try:
        parameters = signature(callable).parameters
    except ValueError:
        return False
    return "config" in parameters
def accepts_context(callable: Callable[..., Any]) -> bool:
    """Return whether ``callable`` declares a ``context`` parameter.

    Returns ``False`` when ``inspect.signature`` cannot determine a signature
    (it raises ``ValueError``).
    """
    try:
        parameters = signature(callable).parameters
    except ValueError:
        return False
    return "context" in parameters
class IsLocalDict(ast.NodeVisitor):
    """Collect the string keys read from a given variable via dict-style access.

    Records ``name["key"]`` subscript loads and ``name.get("key"[, default])``
    calls whose key is a string literal, adding each key to ``keys``.
    """

    def __init__(self, name: str, keys: Set[str]) -> None:
        """Initialize the visitor.

        Args:
            name: The variable name whose dict-style accesses are tracked.
            keys: Set that discovered keys are added to (mutated in place).
        """
        self.name = name
        self.keys = keys

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Record ``name["key"]`` loads, then keep walking the subtree."""
        if (
            isinstance(node.ctx, ast.Load)
            and isinstance(node.value, ast.Name)
            and node.value.id == self.name
            and isinstance(node.slice, ast.Constant)
            and isinstance(node.slice.value, str)
        ):
            # we've found a subscript access on the name we're looking for
            self.keys.add(node.slice.value)
        # Defining visit_Subscript replaces NodeVisitor's default recursion, so
        # recurse explicitly; otherwise accesses nested inside this node (e.g.
        # the inner x["a"] of x["a"]["b"], or y[x["a"]]) would be missed.
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> Any:
        """Record ``name.get("key")`` calls, then keep walking the subtree."""
        if (
            isinstance(node.func, ast.Attribute)
            and isinstance(node.func.value, ast.Name)
            and node.func.value.id == self.name
            and node.func.attr == "get"
            and len(node.args) in (1, 2)
            and isinstance(node.args[0], ast.Constant)
            and isinstance(node.args[0].value, str)
        ):
            # we've found a .get() call on the name we're looking for
            self.keys.add(node.args[0].value)
        # Recurse so accesses used as call arguments, e.g. len(x["a"]), are seen.
        self.generic_visit(node)
class IsFunctionArgDict(ast.NodeVisitor):
    """Collect dict keys read from the first parameter of functions and lambdas.

    The parameter inspected is ``args.args[0]``; definitions without entries in
    ``args.args`` are skipped. Keys accumulate in ``self.keys``.
    """

    def __init__(self) -> None:
        """Start with an empty set of discovered keys."""
        self.keys: Set[str] = set()

    def _scan(
        self,
        definition: Union[ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef],
        region: ast.AST,
    ) -> None:
        """Scan ``region`` for key reads on ``definition``'s first parameter."""
        params = definition.args.args
        if params:
            IsLocalDict(params[0].arg, self.keys).visit(region)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Scan only the lambda's body expression."""
        self._scan(node, node.body)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Scan the whole function definition node."""
        self._scan(node, node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Scan the whole async function definition node."""
        self._scan(node, node)
class NonLocals(ast.NodeVisitor):
    """Collect the names, and dotted attribute chains, that an AST loads and stores."""

    def __init__(self) -> None:
        """Start with empty ``loads`` and ``stores`` sets."""
        self.loads: Set[str] = set()
        self.stores: Set[str] = set()

    def visit_Name(self, node: ast.Name) -> Any:
        """Record a bare name as loaded or stored, depending on its context."""
        if isinstance(node.ctx, ast.Load):
            self.loads.add(node.id)
        elif isinstance(node.ctx, ast.Store):
            self.stores.add(node.id)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Record a loaded attribute chain such as ``a.b.c`` as one dotted name.

        If the chain is rooted at a plain name, the full dotted path is added to
        ``loads`` and the bare root name is discarded from it. If the chain is
        rooted at some other expression (a call, subscript, ...), that
        expression is visited so the names inside it are still collected.
        Attributes in store/delete context are not recorded.
        """
        if isinstance(node.ctx, ast.Load):
            parent = node.value
            attr_expr = node.attr
            while isinstance(parent, ast.Attribute):
                attr_expr = parent.attr + "." + attr_expr
                parent = parent.value
            if isinstance(parent, ast.Name):
                self.loads.add(parent.id + "." + attr_expr)
                self.loads.discard(parent.id)
            else:
                # e.g. ``chain.invoke(x).content``: the root is a Call. This
                # method replaces the default recursion, so without this visit
                # ``chain.invoke`` and ``x`` would be skipped entirely.
                self.visit(parent)
class FunctionNonLocals(ast.NodeVisitor):
    """Collect names a function or lambda loads without also storing them.

    Each top-level function/lambda encountered is scanned as a whole with
    ``NonLocals``; this visitor does not descend any further on its own.
    """

    def __init__(self) -> None:
        """Start with an empty set of nonlocal names."""
        self.nonlocals: Set[str] = set()

    def _collect(self, node: ast.AST) -> None:
        """Add the names ``node`` loads but never stores to ``nonlocals``."""
        scanner = NonLocals()
        scanner.visit(node)
        self.nonlocals |= scanner.loads - scanner.stores

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Collect nonlocal names from a function definition."""
        self._collect(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
        """Collect nonlocal names from an async function definition."""
        self._collect(node)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Collect nonlocal names from a lambda."""
        self._collect(node)
class GetLambdaSource(ast.NodeVisitor):
    """Count lambda expressions in an AST and capture the source of the last one."""

    def __init__(self) -> None:
        """Initialize with no captured source and a lambda count of zero."""
        self.source: Optional[str] = None
        self.count = 0

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Count this lambda and, when ``ast.unparse`` exists, keep its source.

        Lambdas nested inside this one are not visited, so they are neither
        counted nor captured.
        """
        self.count += 1
        unparse = getattr(ast, "unparse", None)
        if unparse is not None:
            self.source = unparse(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
    """Get the keys read from ``func``'s first argument when it is used as a dict.

    Parses the source retrieved for ``func`` and collects string keys read
    from its first parameter via ``arg["key"]`` or ``arg.get("key")``.

    Args:
        func: The function or lambda to inspect.

    Returns:
        The discovered keys in sorted order, or ``None`` if no keys were found
        or the source could not be retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = IsFunctionArgDict()
        visitor.visit(tree)
        # Sort so the result does not depend on set iteration order, which for
        # strings varies between interpreter runs (hash randomization).
        return sorted(visitor.keys) if visitor.keys else None
    except (SyntaxError, TypeError, OSError, SystemError):
        return None
def get_lambda_source(func: Callable) -> Optional[str]:
    """Get the source code of a lambda function.

    Args:
        func: a callable that can be a lambda function

    Returns:
        str: the source code of the lambda function. If exactly one lambda is
        not found in the retrieved source, or the source cannot be read or
        parsed, the callable's ``__name__`` is returned instead (``None`` for
        anonymous lambdas or objects without ``__name__``).
    """
    fallback: Optional[str]
    try:
        fallback = None if func.__name__ == "<lambda>" else func.__name__
    except AttributeError:
        fallback = None

    try:
        tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
        finder = GetLambdaSource()
        finder.visit(tree)
    except (SyntaxError, TypeError, OSError, SystemError):
        return fallback

    # Only trust the captured source when it is unambiguous.
    return finder.source if finder.count == 1 else fallback
def get_function_nonlocals(func: Callable) -> List[Any]:
    """Get the nonlocal variables accessed by a function.

    Parses the source retrieved for ``func`` to find the names it reads. For
    every closure variable of ``func`` it then includes:

    * the variable's value, if the bare name is read, and
    * for each dotted access rooted at that variable (``var.a.b``), the object
      reached by following the attributes. The access is skipped if an
      intermediate value is ``None`` or an attribute is missing.

    Args:
        func: The function to inspect.

    Returns:
        The collected values, or an empty list if the source could not be
        retrieved or parsed.
    """
    try:
        code = inspect.getsource(func)
        tree = ast.parse(textwrap.dedent(code))
        visitor = FunctionNonLocals()
        visitor.visit(tree)
        values: List[Any] = []
        for k, v in inspect.getclosurevars(func).nonlocals.items():
            if k in visitor.nonlocals:
                values.append(v)
            # Require the dot boundary: a bare startswith(k) would also match an
            # unrelated name such as "kx.attr" and resolve "attr" on k's value.
            dotted_prefix = k + "."
            for kk in visitor.nonlocals:
                if not kk.startswith(dotted_prefix):
                    continue
                vv = v
                for part in kk.split(".")[1:]:
                    if vv is None:
                        break
                    try:
                        vv = getattr(vv, part)
                    except AttributeError:
                        break
                else:
                    values.append(vv)
        return values
    except (SyntaxError, TypeError, OSError, SystemError):
        return []
def indent_lines_after_first(text: str, prefix: str) -> str:
    """Indent every line after the first by ``len(prefix)`` spaces.

    Args:
        text: The text to indent.
        prefix: String whose length sets the number of spaces to use; its
            characters themselves are not inserted.

    Returns:
        str: The indented text. Lines are re-joined with a newline, so line
        endings are normalized and a trailing newline is dropped. Empty text
        is returned unchanged.
    """
    lines = text.splitlines()
    if not lines:
        # splitlines() returns [] for "", so there is no first line to keep.
        return text
    spaces = " " * len(prefix)
    return "\n".join([lines[0], *(spaces + line for line in lines[1:])])
class AddableDict(Dict[str, Any]):
    """Dictionary that can be added to another dictionary.

    Addition merges key by key. A key that is missing (or ``None``) on the left
    takes the right-hand value. When both values are non-``None`` they are
    combined with ``+``; if that raises ``TypeError``, the right-hand value
    wins. The result is a new ``AddableDict``.
    """

    @staticmethod
    def _merge(left: Any, right: Any) -> AddableDict:
        """Return a new AddableDict built from ``left`` with ``right`` merged in."""
        merged = AddableDict(left)
        for key in right:
            incoming = right[key]
            if key not in merged or merged[key] is None:
                merged[key] = incoming
            elif incoming is not None:
                try:
                    merged[key] = merged[key] + incoming
                except TypeError:
                    merged[key] = incoming
        return merged

    def __add__(self, other: AddableDict) -> AddableDict:
        return self._merge(self, other)

    def __radd__(self, other: AddableDict) -> AddableDict:
        return self._merge(other, self)
# Variance-annotated type variables used to parameterize the SupportsAdd protocol.
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
class SupportsAdd(Protocol[_T_contra, _T_co]):
    """Protocol for objects that support addition."""

    # The double-underscore parameter name marks ``__x`` as positional-only
    # for type checkers.
    def __add__(self, __x: _T_contra) -> _T_co: ...
# Type variable for values that can be combined with ``+`` (bound to SupportsAdd).
Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
def add(addables: Iterable[Addable]) -> Optional[Addable]:
    """Add a sequence of addable objects together, left to right.

    A ``None`` accumulator is replaced by the next item rather than added to.
    Returns ``None`` when ``addables`` is empty.
    """
    total: Optional[Addable] = None
    for item in addables:
        total = item if total is None else total + item
    return total
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
    """Asynchronously add a sequence of addable objects together, left to right.

    A ``None`` accumulator is replaced by the next item rather than added to.
    Returns ``None`` when ``addables`` yields nothing.
    """
    total: Optional[Addable] = None
    async for item in addables:
        total = item if total is None else total + item
    return total
class ConfigurableField(NamedTuple):
    """Field that can be configured by the user."""

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    annotation: Optional[Any] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        # Only id and annotation feed the hash. Equal instances share both, so
        # this stays consistent with the tuple equality NamedTuple provides.
        return hash((self.id, self.annotation))
class ConfigurableFieldSingleOption(NamedTuple):
    """Field that can be configured by the user with a default value.

    The hash covers ``id``, the set of option keys, and ``default``. Option
    keys are hashed as a ``frozenset`` because NamedTuple equality compares
    ``options`` mappings regardless of insertion order. An order-sensitive
    hash would let equal instances hash differently.
    """

    id: str
    options: Mapping[str, Any]
    # NOTE(review): presumably a key of ``options``; not validated here.
    default: str
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        return hash((self.id, frozenset(self.options.keys()), self.default))
class ConfigurableFieldMultiOption(NamedTuple):
    """Field that can be configured by the user with multiple default values.

    The hash covers ``id``, the set of option keys, and the ``default``
    sequence. Option keys are hashed as a ``frozenset`` because NamedTuple
    equality compares ``options`` mappings regardless of insertion order. An
    order-sensitive hash would let equal instances hash differently.
    """

    id: str
    options: Mapping[str, Any]
    # NOTE(review): presumably keys of ``options``; not validated here.
    default: Sequence[str]
    name: Optional[str] = None
    description: Optional[str] = None
    is_shared: bool = False

    def __hash__(self) -> int:
        return hash(
            (self.id, frozenset(self.options.keys()), tuple(self.default))
        )
# Union of the user-configurable field types defined above.
AnyConfigurableField = Union[
    ConfigurableField, ConfigurableFieldSingleOption, ConfigurableFieldMultiOption
]
class ConfigurableFieldSpec(NamedTuple):
    """Specification of a field that can be configured by the user."""

    id: str
    annotation: Any
    name: Optional[str] = None
    description: Optional[str] = None
    default: Any = None
    is_shared: bool = False
    # NOTE(review): a non-None list here makes the instance unhashable, since
    # no custom __hash__ is defined and lists are unhashable.
    dependencies: Optional[List[str]] = None
def get_unique_config_specs(
    specs: Iterable[ConfigurableFieldSpec],
) -> List[ConfigurableFieldSpec]:
    """Get the unique config specs from a sequence of config specs.

    Specs are grouped by ``id``. Identical duplicates collapse to one entry.

    Args:
        specs: The config specs to deduplicate.

    Returns:
        One spec per id, ordered by ``(id, *dependencies)``.

    Raises:
        ValueError: If two specs share an id but are not equal.
    """
    grouped = groupby(
        sorted(specs, key=lambda s: (s.id, *(s.dependencies or []))), lambda s: s.id
    )
    unique: List[ConfigurableFieldSpec] = []
    for spec_id, dupes in grouped:
        first = next(dupes)
        others = list(dupes)
        # Vacuously true when the id appears only once.
        if all(o == first for o in others):
            unique.append(first)
        else:
            raise ValueError(
                "RunnableSequence contains conflicting config specs "
                f"for {spec_id}: {[first] + others}"
            )
    return unique
class _RootEventFilter:
    def __init__(
        self,
        *,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
    ) -> None:
        """Utility to filter the root event in the astream_events implementation.

        This only binds the arguments onto the instance, to save a bit of
        typing in the astream_events implementation. ``None`` disables the
        corresponding filter.
        """
        self.include_names = include_names
        self.include_types = include_types
        self.include_tags = include_tags
        self.exclude_names = exclude_names
        self.exclude_types = exclude_types
        self.exclude_tags = exclude_tags

    def include_event(self, event: StreamEvent, root_type: str) -> bool:
        """Determine whether to include an event.

        With no include filters, every event starts out included. Otherwise
        the event must match at least one include filter: its name,
        ``root_type``, or any of its tags. Any matching exclude filter then
        rejects the event.
        """
        no_include_filters = (
            self.include_names is None
            and self.include_types is None
            and self.include_tags is None
        )
        event_tags = event.get("tags") or []

        if no_include_filters:
            keep = True
        else:
            keep = (
                (
                    self.include_names is not None
                    and event["name"] in self.include_names
                )
                or (
                    self.include_types is not None
                    and root_type in self.include_types
                )
                or (
                    self.include_tags is not None
                    and any(tag in self.include_tags for tag in event_tags)
                )
            )

        # Exclusions only matter for events that are still included.
        if keep and self.exclude_names is not None:
            keep = event["name"] not in self.exclude_names
        if keep and self.exclude_types is not None:
            keep = root_type not in self.exclude_types
        if keep and self.exclude_tags is not None:
            keep = all(tag not in self.exclude_tags for tag in event_tags)
        return keep


class _SchemaConfig(BaseConfig):
    # Pydantic config passed as __config__ by create_model / _create_model_cached.
    arbitrary_types_allowed = True
    frozen = True
def create_model(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    A cached model is reused when possible. If the cached path raises
    ``TypeError``, the model is built without the cache.

    Args:
        __model_name: The name of the model.
        **field_definitions: The field definitions for the model.

    Returns:
        Type[BaseModel]: The created model.
    """
    try:
        model = _create_model_cached(__model_name, **field_definitions)
    except TypeError:
        # something in field definitions is not hashable, so the lru_cache
        # lookup failed; build the model without caching
        model = _create_model_base(
            __model_name, __config__=_SchemaConfig, **field_definitions
        )
    return model
@lru_cache(maxsize=256)
def _create_model_cached(
    __model_name: str,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Build a pydantic model with ``_SchemaConfig``, memoized by ``lru_cache``.

    All arguments form the cache key, so they must be hashable.
    """
    model = _create_model_base(
        __model_name, __config__=_SchemaConfig, **field_definitions
    )
    return model
def is_async_generator(
    func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
    """Check whether ``func`` is an async generator function.

    Objects whose ``__call__`` method is an async generator function count too.
    """
    if inspect.isasyncgenfunction(func):
        return True
    return hasattr(func, "__call__") and inspect.isasyncgenfunction(func.__call__)
def is_async_callable(
    func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
    """Check whether ``func`` is a coroutine function.

    Objects whose ``__call__`` method is a coroutine function count too.
    """
    if asyncio.iscoroutinefunction(func):
        return True
    return hasattr(func, "__call__") and asyncio.iscoroutinefunction(func.__call__)