Source code for langchain.chains.base

"""所有链条都应该实现的基本接口。"""

import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union, cast

import yaml
from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManager,
    AsyncCallbackManagerForChainRun,
    BaseCallbackManager,
    CallbackManager,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain_core.load.dump import dumpd
from langchain_core.memory import BaseMemory
from langchain_core.outputs import RunInfo
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
from langchain_core.runnables import (
    RunnableConfig,
    RunnableSerializable,
    ensure_config,
    run_in_executor,
)
from langchain_core.runnables.utils import create_model

from langchain.schema import RUN_KEY

logger = logging.getLogger(__name__)


def _get_verbosity() -> bool:
    from langchain.globals import get_verbose

    return get_verbose()


[docs]class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): """用于创建组件调用序列的结构化序列的抽象基类。 应该使用Chains来编码对组件(如模型、文档检索器、其他链等)的调用序列,并为这个序列提供一个简单的接口。 Chain接口使得创建以下应用变得容易: - 有状态的:向任何Chain添加Memory以赋予其状态, - 可观察的:将Callbacks传递给Chain以执行额外功能,如日志记录,在组件调用主序列之外, - 可组合的:Chain API足够灵活,易于将Chains与其他组件(包括其他Chains)组合在一起。 Chains公开的主要方法有: - `__call__`:Chains是可调用的。`__call__`方法是执行Chain的主要方式。它接受字典形式的输入,并返回字典形式的输出。 - `run`:一个方便的方法,以args/kwargs形式接受输入,并将输出作为字符串或对象返回。此方法仅适用于部分链,并且无法像`__call__`那样返回丰富的输出。""" memory: Optional[BaseMemory] = None """可选的内存对象。默认值为None。 Memory是一个在每个链的开始和结束时被调用的类。在开始时,内存加载变量并在链中传递它们。在结束时,它保存任何返回的变量。 有许多不同类型的内存 - 请参阅内存文档以获取完整目录。""" callbacks: Callbacks = Field(default=None, exclude=True) """可选的回调处理程序列表(或回调管理器)。默认为None。 回调处理程序在调用链的整个生命周期中被调用,从on_chain_start开始,到on_chain_end或on_chain_error结束。 每个自定义链可以选择性地调用其他回调方法,请参阅回调文档以获取完整详情。""" verbose: bool = Field(default_factory=_get_verbosity) """是否在详细模式下运行。在详细模式下,一些中间日志将被打印到控制台。默认为全局`verbose`值,可通过`langchain.globals.get_verbose()`访问。""" tags: Optional[List[str]] = None """可选的与链相关联的标签列表。默认为None。 这些标签将与对该链的每次调用相关联, 并作为参数传递给在`callbacks`中定义的处理程序。 您可以使用这些标签来识别链的特定实例及其用例。""" metadata: Optional[Dict[str, Any]] = None """与链相关的可选元数据。默认为None。 此元数据将与对该链的每次调用相关联, 并作为参数传递给在`callbacks`中定义的处理程序。 您可以使用这些来识别链的特定实例及其用例。""" callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """[已弃用] 使用`callbacks`代替。""" class Config: """这个pydantic对象的配置。""" arbitrary_types_allowed = True
[docs] def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainInput", **{k: (Any, None) for k in self.input_keys} )
[docs] def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainOutput", **{k: (Any, None) for k in self.output_keys} )
[docs] def invoke( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: config = ensure_config(config) callbacks = config.get("callbacks") tags = config.get("tags") metadata = config.get("metadata") run_name = config.get("run_name") or self.get_name() run_id = config.get("run_id") include_run_info = kwargs.get("include_run_info", False) return_only_outputs = kwargs.get("return_only_outputs", False) inputs = self.prep_inputs(input) callback_manager = CallbackManager.configure( callbacks, self.callbacks, self.verbose, tags, self.tags, metadata, self.metadata, ) new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") run_manager = callback_manager.on_chain_start( dumpd(self), inputs, run_id, name=run_name, ) try: self._validate_inputs(inputs) outputs = ( self._call(inputs, run_manager=run_manager) if new_arg_supported else self._call(inputs) ) final_outputs: Dict[str, Any] = self.prep_outputs( inputs, outputs, return_only_outputs ) except BaseException as e: run_manager.on_chain_error(e) raise e run_manager.on_chain_end(outputs) if include_run_info: final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) return final_outputs
[docs] async def ainvoke( self, input: Dict[str, Any], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: config = ensure_config(config) callbacks = config.get("callbacks") tags = config.get("tags") metadata = config.get("metadata") run_name = config.get("run_name") or self.get_name() run_id = config.get("run_id") include_run_info = kwargs.get("include_run_info", False) return_only_outputs = kwargs.get("return_only_outputs", False) inputs = await self.aprep_inputs(input) callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, self.verbose, tags, self.tags, metadata, self.metadata, ) new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") run_manager = await callback_manager.on_chain_start( dumpd(self), inputs, run_id, name=run_name, ) try: self._validate_inputs(inputs) outputs = ( await self._acall(inputs, run_manager=run_manager) if new_arg_supported else await self._acall(inputs) ) final_outputs: Dict[str, Any] = await self.aprep_outputs( inputs, outputs, return_only_outputs ) except BaseException as e: await run_manager.on_chain_error(e) raise e await run_manager.on_chain_end(outputs) if include_run_info: final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) return final_outputs
@property def _chain_type(self) -> str: raise NotImplementedError("Saving not supported for this chain type.") @root_validator() def raise_callback_manager_deprecation(cls, values: Dict) -> Dict: """如果使用callback_manager,则发出弃用警告。""" if values.get("callback_manager") is not None: if values.get("callbacks") is not None: raise ValueError( "Cannot specify both callback_manager and callbacks. " "callback_manager is deprecated, callbacks is the preferred " "parameter to pass in." ) warnings.warn( "callback_manager is deprecated. Please use callbacks instead.", DeprecationWarning, ) values["callbacks"] = values.pop("callback_manager", None) return values @validator("verbose", pre=True, always=True) def set_verbose(cls, verbose: Optional[bool]) -> bool: """设置链的详细程度。 如果用户未指定,则默认为全局设置。 """ if verbose is None: return _get_verbosity() else: return verbose @property @abstractmethod def input_keys(self) -> List[str]: """预期在链输入中的键。""" @property @abstractmethod def output_keys(self) -> List[str]: """链输出中预期存在的键。""" def _validate_inputs(self, inputs: Dict[str, Any]) -> None: """检查所有输入是否都存在。""" if not isinstance(inputs, dict): _input_keys = set(self.input_keys) if self.memory is not None: # If there are multiple input keys, but some get set by memory so that # only one is not set, we can still figure out which key it is. _input_keys = _input_keys.difference(self.memory.memory_variables) if len(_input_keys) != 1: raise ValueError( f"A single string input was passed in, but this chain expects " f"multiple inputs ({_input_keys}). When a chain expects " f"multiple inputs, please call it by passing in a dictionary, " "eg `chain({'foo': 1, 'bar': 2})`" ) missing_keys = set(self.input_keys).difference(inputs) if missing_keys: raise ValueError(f"Missing some input keys: {missing_keys}") def _validate_outputs(self, outputs: Dict[str, Any]) -> None: missing_keys = set(self.output_keys).difference(outputs) if missing_keys: raise ValueError(f"Missing some output keys: {missing_keys}") @abstractmethod def _call( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: """执行链条。 这是一个私有方法,不会暴露给用户。它只会在`Chain.__call__`中被调用,该方法是用户可见的包装方法,用于处理回调配置和一些输入/输出处理。 参数: inputs: 一个包含链条中所有指定输入的命名输入字典,包括内存中添加的任何输入。 run_manager: 包含该链条运行的回调处理程序的回调管理器。 返回: 一个包含命名输出的字典。应该包含在`Chain.output_keys`中指定的所有输出。 """ async def _acall( self, inputs: Dict[str, Any], run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: """异步执行链。 这是一个私有方法,不向用户公开。它仅在`Chain.acall`内部调用,该方法是用户可见的包装方法,处理回调配置和一些输入/输出处理。 参数: inputs:链的命名输入字典。假定包含`Chain.input_keys`中指定的所有输入,包括内存中添加的任何输入。 run_manager:包含此链运行的回调处理程序的回调管理器。 返回: 命名输出的字典。应包含`Chain.output_keys`中指定的所有输出。 """ return await run_in_executor( None, self._call, inputs, run_manager.get_sync() if run_manager else None )
[docs] @deprecated("0.1.0", alternative="invoke", removal="0.3.0") def __call__( self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """执行链。 参数: inputs: 输入的字典,或者如果链只需要一个参数,则为单个输入。应包含在`Chain.input_keys`中指定的所有输入,除了将由链的内存设置的输入。 return_only_outputs: 是否仅在响应中返回输出。如果为True,则仅返回此链生成的新键。如果为False,则返回输入键和此链生成的新键。默认为False。 callbacks: 用于此链运行的回调。除了在构造期间传递给链的回调之外,还将调用这些回调,但是只有这些运行时回调将传播到对其他对象的调用。 tags: 要传递给所有回调的字符串标签列表。这些将被传递给链在构造期间传递的标签之外,但是只有这些运行时标签将传播到对其他对象的调用。 metadata: 与链关联的可选元数据。默认为None include_run_info: 是否在响应中包含运行信息。默认为False。 返回: 一个命名输出的字典。应包含在`Chain.output_keys`中指定的所有输出。 """ config = { "callbacks": callbacks, "tags": tags, "metadata": metadata, "run_name": run_name, } return self.invoke( inputs, cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}), return_only_outputs=return_only_outputs, include_run_info=include_run_info, )
[docs] @deprecated("0.1.0", alternative="ainvoke", removal="0.3.0") async def acall( self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """异步执行链。 参数: inputs:输入的字典,或者如果链只需要一个参数,则为单个输入。应包含在`Chain.input_keys`中指定的所有输入,除了将由链的内存设置的输入。 return_only_outputs:是否仅在响应中返回输出。如果为True,则仅返回此链生成的新键。如果为False,则返回输入键和此链生成的新键。默认为False。 callbacks:用于此链运行的回调。除了在构造期间传递给链的回调之外,还将调用这些回调,但是只有这些运行时回调将传播到对其他对象的调用。 tags:要传递给所有回调的字符串标签列表。这些将被传递给链在构造期间传递的标签之外,但是只有这些运行时标签将传播到对其他对象的调用。 metadata:与链关联的可选元数据。默认为None。 include_run_info:是否在响应中包含运行信息。默认为False。 返回: 命名输出的字典。应包含在`Chain.output_keys`中指定的所有输出。 """ config = { "callbacks": callbacks, "tags": tags, "metadata": metadata, "run_name": run_name, } return await self.ainvoke( inputs, cast(RunnableConfig, {k: v for k, v in config.items() if k is not None}), return_only_outputs=return_only_outputs, include_run_info=include_run_info, )
[docs] def prep_outputs( self, inputs: Dict[str, str], outputs: Dict[str, str], return_only_outputs: bool = False, ) -> Dict[str, str]: """验证和准备链输出,并将此运行的信息保存到内存中。 参数: inputs:包括链内存中添加的任何输入的链输入字典。 outputs:初始链输出字典。 return_only_outputs:是否仅返回链输出。如果为False,则还将输入添加到最终输出中。 返回: 最终链输出的字典。 """ self._validate_outputs(outputs) if self.memory is not None: self.memory.save_context(inputs, outputs) if return_only_outputs: return outputs else: return {**inputs, **outputs}
[docs] async def aprep_outputs( self, inputs: Dict[str, str], outputs: Dict[str, str], return_only_outputs: bool = False, ) -> Dict[str, str]: """验证和准备链输出,并将此运行的信息保存到内存中。 参数: inputs:包括链内存中添加的任何输入的链输入字典。 outputs:初始链输出字典。 return_only_outputs:是否仅返回链输出。如果为False,则还将输入添加到最终输出中。 返回: 最终链输出的字典。 """ self._validate_outputs(outputs) if self.memory is not None: await self.memory.asave_context(inputs, outputs) if return_only_outputs: return outputs else: return {**inputs, **outputs}
[docs] def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: """准备链的输入,包括从内存中添加输入。 参数: inputs: 原始输入的字典,或者如果链只需要一个参数,则为单个输入。应包含`Chain.input_keys`中指定的所有输入,除了将由链的内存设置的输入。 返回: 包括链的内存添加的所有输入的字典。 """ if not isinstance(inputs, dict): _input_keys = set(self.input_keys) if self.memory is not None: # If there are multiple input keys, but some get set by memory so that # only one is not set, we can still figure out which key it is. _input_keys = _input_keys.difference(self.memory.memory_variables) inputs = {list(_input_keys)[0]: inputs} if self.memory is not None: external_context = self.memory.load_memory_variables(inputs) inputs = dict(inputs, **external_context) return inputs
[docs] async def aprep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: """准备链的输入,包括从内存中添加输入。 参数: inputs: 原始输入的字典,或者如果链只需要一个参数,则为单个输入。应包含`Chain.input_keys`中指定的所有输入,除了将由链的内存设置的输入。 返回: 包括链的内存添加的所有输入的字典。 """ if not isinstance(inputs, dict): _input_keys = set(self.input_keys) if self.memory is not None: # If there are multiple input keys, but some get set by memory so that # only one is not set, we can still figure out which key it is. _input_keys = _input_keys.difference(self.memory.memory_variables) inputs = {list(_input_keys)[0]: inputs} if self.memory is not None: external_context = await self.memory.aload_memory_variables(inputs) inputs = dict(inputs, **external_context) return inputs
@property def _run_output_key(self) -> str: if len(self.output_keys) != 1: raise ValueError( f"`run` not supported when there is not exactly " f"one output key. Got {self.output_keys}." ) return self.output_keys[0]
[docs] @deprecated("0.1.0", alternative="invoke", removal="0.3.0") def run( self, *args: Any, callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """方便执行链的方法。 这个方法与`Chain.__call__`方法的主要区别在于,这个方法期望直接将输入作为位置参数或关键字参数传递,而`Chain.__call__`方法期望一个包含所有输入的单个输入字典。 参数: *args: 如果链期望单个输入,则可以将其作为唯一的位置参数传递。 callbacks: 用于此链运行的回调。这些回调将被调用,除了在构建过程中传递给链的回调之外,但只有这些运行时回调会传播到对其他对象的调用。 tags: 要传递给所有回调的字符串标签列表。这些标签将被传递,除了在构建过程中传递给链的标签之外,但只有这些运行时标签会传播到对其他对象的调用。 **kwargs: 如果链期望多个输入,则可以直接将它们作为关键字参数传递。 返回: 链的输出。 示例: .. code-block:: python # 假设我们有一个单输入链,接受一个'question'字符串: chain.run("What's the temperature in Boise, Idaho?") # -> "The temperature in Boise is..." # 假设我们有一个多输入链,接受一个'question'字符串和一个'context'字符串: question = "What's the temperature in Boise, Idaho?" context = "Weather report for Boise, Idaho on 07/03/23..." chain.run(question=question, context=context) # -> "The temperature in Boise is..." """ # Run at start to make sure this is possible/defined _output_key = self._run_output_key if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[ _output_key ] if kwargs and not args: return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[ _output_key ] if not kwargs and not args: raise ValueError( "`run` supported with either positional arguments or keyword arguments," " but none were provided." ) else: raise ValueError( f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}." )
[docs] @deprecated("0.1.0", alternative="ainvoke", removal="0.3.0") async def arun( self, *args: Any, callbacks: Callbacks = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """方便执行链的方法。 这个方法和`Chain.__call__`方法的主要区别在于,这个方法期望直接将输入作为位置参数或关键字参数传递,而`Chain.__call__`方法期望一个包含所有输入的单个输入字典。 参数: *args: 如果链期望单个输入,可以作为唯一的位置参数传入。 callbacks: 用于此链运行的回调。这些回调将会被调用,除了在构建期间传递给链的回调之外,但是只有这些运行时回调会传播到对其他对象的调用。 tags: 要传递给所有回调的字符串标签列表。这些标签将会被传递,除了在构建期间传递给链的标签之外,但是只有这些运行时标签会传播到对其他对象的调用。 **kwargs: 如果链期望多个输入,它们可以直接作为关键字参数传入。 返回值: 链的输出。 示例: .. code-block:: python # 假设我们有一个单输入链,接受一个'question'字符串: await chain.arun("What's the temperature in Boise, Idaho?") # -> "The temperature in Boise is..." # 假设我们有一个多输入链,接受一个'question'字符串和一个'context'字符串: question = "What's the temperature in Boise, Idaho?" context = "Weather report for Boise, Idaho on 07/03/23..." await chain.arun(question=question, context=context) # -> "The temperature in Boise is..." """ if len(self.output_keys) != 1: raise ValueError( f"`run` not supported when there is not exactly " f"one output key. Got {self.output_keys}." ) elif args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") return ( await self.acall( args[0], callbacks=callbacks, tags=tags, metadata=metadata ) )[self.output_keys[0]] if kwargs and not args: return ( await self.acall( kwargs, callbacks=callbacks, tags=tags, metadata=metadata ) )[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}." )
[docs] def dict(self, **kwargs: Any) -> Dict: """链的字典表示。 期望`Chain._chain_type`属性被实现,并且内存为空。 参数: **kwargs: 传递给默认`pydantic.BaseModel.dict`方法的关键字参数。 返回: 链的字典表示。 示例: .. code-block:: python chain.dict(exclude_unset=True) # -> {"_type": "foo", "verbose": False, ...} """ _dict = super().dict(**kwargs) try: _dict["_type"] = self._chain_type except NotImplementedError: pass return _dict
[docs] def save(self, file_path: Union[Path, str]) -> None: """保存链条。 期望`Chain._chain_type`属性已实现,并且内存为空。 参数: file_path:保存链条的文件路径。 示例: .. code-block:: python chain.save(file_path="path/chain.yaml") """ if self.memory is not None: raise ValueError("Saving of memory is not yet supported.") # Fetch dictionary to save chain_dict = self.dict() if "_type" not in chain_dict: raise NotImplementedError(f"Chain {self} does not support saving.") # Convert file to Path object. if isinstance(file_path, str): save_path = Path(file_path) else: save_path = file_path directory_path = save_path.parent directory_path.mkdir(parents=True, exist_ok=True) if save_path.suffix == ".json": with open(file_path, "w") as f: json.dump(chain_dict, f, indent=4) elif save_path.suffix.endswith((".yaml", ".yml")): with open(file_path, "w") as f: yaml.dump(chain_dict, f, default_flow_style=False) else: raise ValueError(f"{save_path} must be json or yaml")
[docs] @deprecated("0.1.0", alternative="batch", removal="0.3.0") def apply( self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> List[Dict[str, str]]: """对列表中的所有输入调用链。""" return [self(inputs, callbacks=callbacks) for inputs in input_list]