Source code for langchain_core.runnables.history

from __future__ import annotations

import inspect
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load.load import load
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    create_model,
    get_unique_config_specs,
)

if TYPE_CHECKING:
    from langchain_core.language_models import LanguageModelLike
    from langchain_core.messages import BaseMessage
    from langchain_core.runnables.config import RunnableConfig
    from langchain_core.tracers.schemas import Run


MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]


class RunnableWithMessageHistory(RunnableBindingBase):
    """Runnable that manages chat message history for another Runnable.

    A chat message history is a sequence of messages that represent a conversation.

    RunnableWithMessageHistory wraps another Runnable and manages the chat message
    history for it; it is responsible for reading and updating the chat message
    history.

    The formats supported for the inputs and outputs of the wrapped Runnable are
    described below.

    RunnableWithMessageHistory must always be invoked with a config that contains
    the appropriate parameters for the chat message history factory.

    By default, the Runnable is expected to take a single configuration parameter
    called ``session_id``, which is a string. This parameter is used to create a
    new, or look up an existing, chat message history that matches the given
    session_id.

    In that case, the invocation looks like
    ``with_history.invoke(..., config={"configurable": {"session_id": "bar"}})``;
    e.g., ``{"configurable": {"session_id": "<SESSION_ID>"}}``.

    The configuration can be customized by passing a list of
    ``ConfigurableFieldSpec`` objects to the ``history_factory_config`` parameter
    (see the sketches below).

    The sketches below use an in-memory implementation of chat message history so
    that it is easy to experiment and inspect the results.

    For production use cases, you will want to use a persistent implementation of
    chat message history, such as ``RedisChatMessageHistory``.
    """  # noqa: E501

    get_session_history: GetSessionHistoryCallable
    input_messages_key: Optional[str] = None
    output_messages_key: Optional[str] = None
    history_messages_key: Optional[str] = None
    history_factory_config: Sequence[ConfigurableFieldSpec]
    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]
    def __init__(
        self,
        runnable: Union[
            Runnable[
                Union[MessagesOrDictWithMessages],
                Union[str, BaseMessage, MessagesOrDictWithMessages],
            ],
            LanguageModelLike,
        ],
        get_session_history: GetSessionHistoryCallable,
        *,
        input_messages_key: Optional[str] = None,
        output_messages_key: Optional[str] = None,
        history_messages_key: Optional[str] = None,
        history_factory_config: Optional[Sequence[ConfigurableFieldSpec]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize RunnableWithMessageHistory.

        Args:
            runnable: The base Runnable to be wrapped. Must take as input one of:
                1. A sequence of BaseMessages
                2. A dict with one key for all messages
                3. A dict with one key for the current input string/message(s) and
                   a separate key for historical messages. If the input key points
                   to a string, it will be treated as a HumanMessage in history.

                Must return as output one of:
                1. A string which can be treated as an AIMessage
                2. A BaseMessage or a sequence of BaseMessages
                3. A dict with a key for a BaseMessage or sequence of BaseMessages

            get_session_history: Function that returns a new BaseChatMessageHistory.
                This function should either take a single positional argument
                ``session_id`` of type string and return a corresponding chat
                message history instance:

                .. code-block:: python

                    def get_session_history(
                        session_id: str,
                        *,
                        user_id: Optional[str] = None,
                    ) -> BaseChatMessageHistory:
                        ...

                Or it should take keyword arguments that match the keys of
                ``session_history_config_specs`` and return a corresponding chat
                message history instance:

                .. code-block:: python

                    def get_session_history(
                        *,
                        user_id: str,
                        thread_id: str,
                    ) -> BaseChatMessageHistory:
                        ...

            input_messages_key: Must be specified if the base runnable accepts a
                dict as input.
            output_messages_key: Must be specified if the base runnable returns a
                dict as output.
            history_messages_key: Must be specified if the base runnable accepts a
                dict as input and expects a separate key for historical messages.
            history_factory_config: Configure fields that should be passed to the
                chat history factory. See ``ConfigurableFieldSpec`` for more
                details. Specifying these allows you to pass multiple config keys
                to the get_session_history factory (see the sketch after this
                class definition).
            **kwargs: Arbitrary additional kwargs to pass to the parent class
                ``RunnableBindingBase`` init.
        """  # noqa: E501
        history_chain: Runnable = RunnableLambda(
            self._enter_history, self._aenter_history
        ).with_config(run_name="load_history")
        messages_key = history_messages_key or input_messages_key
        if messages_key:
            history_chain = RunnablePassthrough.assign(
                **{messages_key: history_chain}
            ).with_config(run_name="insert_history")
        bound = (
            history_chain | runnable.with_listeners(on_end=self._exit_history)
        ).with_config(run_name="RunnableWithMessageHistory")

        if history_factory_config:
            _config_specs = history_factory_config
        else:
            # If not provided, then we'll use the default session_id field
            _config_specs = [
                ConfigurableFieldSpec(
                    id="session_id",
                    annotation=str,
                    name="Session ID",
                    description="Unique identifier for a session.",
                    default="",
                    is_shared=True,
                ),
            ]

        super().__init__(
            get_session_history=get_session_history,
            input_messages_key=input_messages_key,
            output_messages_key=output_messages_key,
            bound=bound,
            history_messages_key=history_messages_key,
            history_factory_config=_config_specs,
            **kwargs,
        )

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        return get_unique_config_specs(
            super().config_specs + list(self.history_factory_config)
        )
    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        super_schema = super().get_input_schema(config)
        if super_schema.__custom_root_type__ or not super_schema.schema().get(
            "properties"
        ):
            from langchain_core.messages import BaseMessage

            fields: Dict = {}
            if self.input_messages_key and self.history_messages_key:
                fields[self.input_messages_key] = (
                    Union[str, BaseMessage, Sequence[BaseMessage]],
                    ...,
                )
            elif self.input_messages_key:
                fields[self.input_messages_key] = (Sequence[BaseMessage], ...)
            else:
                fields["__root__"] = (Sequence[BaseMessage], ...)
            return create_model(  # type: ignore[call-overload]
                "RunnableWithChatHistoryInput",
                **fields,
            )
        else:
            return super_schema
    def _get_input_messages(
        self, input_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
    ) -> List[BaseMessage]:
        from langchain_core.messages import BaseMessage

        if isinstance(input_val, dict):
            if self.input_messages_key:
                key = self.input_messages_key
            elif len(input_val) == 1:
                key = list(input_val.keys())[0]
            else:
                key = "input"
            input_val = input_val[key]

        if isinstance(input_val, str):
            from langchain_core.messages import HumanMessage

            return [HumanMessage(content=input_val)]
        elif isinstance(input_val, BaseMessage):
            return [input_val]
        elif isinstance(input_val, (list, tuple)):
            return list(input_val)
        else:
            raise ValueError(
                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
                f"Got {input_val}."
            )

    def _get_output_messages(
        self, output_val: Union[str, BaseMessage, Sequence[BaseMessage], dict]
    ) -> List[BaseMessage]:
        from langchain_core.messages import BaseMessage

        if isinstance(output_val, dict):
            if self.output_messages_key:
                key = self.output_messages_key
            elif len(output_val) == 1:
                key = list(output_val.keys())[0]
            else:
                key = "output"
            # If a chat model is wrapped directly, the traced output is an
            # LLMResult-like dict whose message lives under "generations".
            if key not in output_val and "generations" in output_val:
                output_val = output_val["generations"][0][0]["message"]
            else:
                output_val = output_val[key]

        if isinstance(output_val, str):
            from langchain_core.messages import AIMessage

            return [AIMessage(content=output_val)]
        elif isinstance(output_val, BaseMessage):
            return [output_val]
        elif isinstance(output_val, (list, tuple)):
            return list(output_val)
        else:
            raise ValueError(
                f"Expected str, BaseMessage, List[BaseMessage], or Tuple[BaseMessage]. "
                f"Got {output_val}."
            )

    def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage]:
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = hist.messages.copy()

        if not self.history_messages_key:
            # return all messages
            messages += self._get_input_messages(input)
        return messages

    async def _aenter_history(
        self, input: Dict[str, Any], config: RunnableConfig
    ) -> List[BaseMessage]:
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]
        messages = (await hist.aget_messages()).copy()

        if not self.history_messages_key:
            # return all messages
            input_val = (
                input if not self.input_messages_key else input[self.input_messages_key]
            )
            messages += self._get_input_messages(input_val)
        return messages

    def _exit_history(self, run: Run, config: RunnableConfig) -> None:
        hist: BaseChatMessageHistory = config["configurable"]["message_history"]

        # Get the input messages
        inputs = load(run.inputs)
        input_messages = self._get_input_messages(inputs)
        # If historic messages were prepended to the input messages, remove them to
        # avoid adding duplicate messages to history.
        if not self.history_messages_key:
            historic_messages = config["configurable"]["message_history"].messages
            input_messages = input_messages[len(historic_messages) :]

        # Get the output messages
        output_val = load(run.outputs)
        output_messages = self._get_output_messages(output_val)
        hist.add_messages(input_messages + output_messages)

    def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
        config = super()._merge_configs(*configs)
        expected_keys = [field_spec.id for field_spec in self.history_factory_config]

        configurable = config.get("configurable", {})

        missing_keys = set(expected_keys) - set(configurable.keys())

        if missing_keys:
            example_input = {self.input_messages_key: "foo"}
            example_configurable = {
                missing_key: "[your-value-here]" for missing_key in missing_keys
            }
            example_config = {"configurable": example_configurable}
            raise ValueError(
                f"Missing keys {sorted(missing_keys)} in config['configurable'] "
                f"Expected keys are {sorted(expected_keys)}. "
                f"When using via .invoke() or .stream(), pass in a config; "
                f"e.g., chain.invoke({example_input}, {example_config})"
            )

        parameter_names = _get_parameter_names(self.get_session_history)

        if len(expected_keys) == 1:
            # If arity = 1, then invoke function by positional arguments
            message_history = self.get_session_history(configurable[expected_keys[0]])
        else:
            # otherwise verify that names of keys match and invoke by named arguments
            if set(expected_keys) != set(parameter_names):
                raise ValueError(
                    f"Expected keys {sorted(expected_keys)} do not match parameter "
                    f"names {sorted(parameter_names)} of get_session_history."
                )

            message_history = self.get_session_history(
                **{key: configurable[key] for key in expected_keys}
            )
        config["configurable"]["message_history"] = message_history
        return config
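# The sketch below illustrates the ``history_factory_config`` customization
# described in ``__init__`` above: instead of the default ``session_id`` field,
# the history factory is keyed by ``user_id`` and ``conversation_id``. The field
# names and the ``_get_session_history`` helper are illustrative assumptions,
# not part of this module.
#
#     def _get_session_history(
#         *, user_id: str, conversation_id: str
#     ) -> BaseChatMessageHistory:
#         ...  # look up or create a history for (user_id, conversation_id)
#
#     with_history = RunnableWithMessageHistory(
#         runnable,
#         _get_session_history,
#         history_factory_config=[
#             ConfigurableFieldSpec(
#                 id="user_id",
#                 annotation=str,
#                 name="User ID",
#                 description="Unique identifier for the user.",
#                 default="",
#                 is_shared=True,
#             ),
#             ConfigurableFieldSpec(
#                 id="conversation_id",
#                 annotation=str,
#                 name="Conversation ID",
#                 description="Unique identifier for the conversation.",
#                 default="",
#                 is_shared=True,
#             ),
#         ],
#     )
#
#     # Both keys must then be supplied in config["configurable"].
#     with_history.invoke(
#         ...,
#         config={"configurable": {"user_id": "u1", "conversation_id": "c1"}},
#     )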
def _get_parameter_names(callable_: GetSessionHistoryCallable) -> List[str]:
    """Get the parameter names of the callable."""
    sig = inspect.signature(callable_)
    return list(sig.parameters.keys())
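# A minimal, self-contained demo of the default ``session_id`` flow described in
# the class docstring: a sketch for local experimentation only, assuming a toy
# in-memory history (``_DemoHistory``), a ``_demo_store`` dict, and an echoing
# stand-in for a chat model. None of these names are part of this module's API;
# production code should use a persistent chat message history instead.
if __name__ == "__main__":
    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage

    class _DemoHistory(BaseChatMessageHistory):
        """Toy in-memory chat message history."""

        def __init__(self) -> None:
            self.messages: List[BaseMessage] = []

        def add_messages(self, messages: Sequence[BaseMessage]) -> None:
            self.messages.extend(messages)

        def clear(self) -> None:
            self.messages = []

    _demo_store: Dict[str, _DemoHistory] = {}

    def _get_by_session_id(session_id: str) -> BaseChatMessageHistory:
        # Create the history on first use, then reuse it for the same session_id.
        return _demo_store.setdefault(session_id, _DemoHistory())

    # Stand-in for a chat model: takes the full message list, returns an AIMessage.
    _fake_model = RunnableLambda(
        lambda messages: AIMessage(content=f"echo: {messages[-1].content}")
    )

    chain = RunnableWithMessageHistory(_fake_model, _get_by_session_id)

    # The per-session history is selected via config["configurable"]["session_id"].
    print(
        chain.invoke(
            [HumanMessage(content="hi")],
            config={"configurable": {"session_id": "demo"}},
        )
    )
    # The "demo" history now holds the HumanMessage and the AIMessage reply.
    print(_demo_store["demo"].messages)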