Source code for langchain.chains.router.base
"""用于链路路由的基类。"""
from __future__ import annotations
from abc import ABC
from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.pydantic_v1 import Extra
from langchain.chains.base import Chain
class Route(NamedTuple):
    """A routing decision: which chain to run next and what to pass it.

    Built by ``RouterChain.route`` / ``RouterChain.aroute`` from the router
    chain's ``destination`` and ``next_inputs`` outputs.
    """

    # Name of the destination chain; may be None (no destination chosen).
    destination: Optional[str]
    # Input dict to hand to the selected chain.
    next_inputs: Dict[str, Any]
class RouterChain(Chain, ABC):
    """Chain that outputs the name of a destination chain and the inputs to it."""

    @property
    def output_keys(self) -> List[str]:
        """A router always emits a ``destination`` and its ``next_inputs``."""
        return ["destination", "next_inputs"]

    @staticmethod
    def _as_route(outputs: Dict[str, Any]) -> Route:
        """Pack the output dict of a router run into a ``Route``."""
        return Route(
            destination=outputs["destination"],
            next_inputs=outputs["next_inputs"],
        )

    def route(self, inputs: Dict[str, Any], callbacks: Callbacks = None) -> Route:
        """Route the inputs to a destination chain.

        Args:
            inputs: Inputs to this router chain.
            callbacks: Callbacks to use for this chain run.

        Returns:
            A ``Route`` with the destination name and the inputs for it.
        """
        return self._as_route(self(inputs, callbacks=callbacks))

    async def aroute(
        self, inputs: Dict[str, Any], callbacks: Callbacks = None
    ) -> Route:
        """Asynchronously route the inputs to a destination chain.

        Same as ``route`` but runs this chain through ``acall``.

        Args:
            inputs: Inputs to this router chain.
            callbacks: Callbacks to use for this chain run.

        Returns:
            A ``Route`` with the destination name and the inputs for it.
        """
        return self._as_route(await self.acall(inputs, callbacks=callbacks))
class MultiRouteChain(Chain):
    """Use a single router chain to send inputs to one of several candidate chains."""

    router_chain: RouterChain
    """Chain that decides which destination chain the inputs are routed to."""
    destination_chains: Mapping[str, Chain]
    """Candidate chains keyed by destination name; the selected one produces the output."""
    default_chain: Chain
    """Chain used when the router returns no destination (or, with
    ``silent_errors``, an unknown one)."""
    silent_errors: bool = False
    """If True, use the default chain when the router returns a destination name
    that is not in ``destination_chains`` instead of raising. Defaults to False."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """The same input keys the router chain expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """No fixed output keys are declared.

        The result of a run is the output dict of whichever destination (or
        default) chain handled the inputs, so its keys vary per call.

        :meta private:
        """
        return []

    def _select_chain(self, route: Route) -> Chain:
        """Pick the chain that should handle ``route``.

        Shared by ``_call`` and ``_acall`` so that the sync and async paths
        always apply the same selection rules.

        Args:
            route: The routing decision produced by the router chain.

        Returns:
            ``default_chain`` when the destination is falsy; the named chain
            when it is in ``destination_chains``; otherwise ``default_chain``
            if ``silent_errors`` is True.

        Raises:
            ValueError: If the destination is unknown and ``silent_errors``
                is False.
        """
        if not route.destination:
            return self.default_chain
        if route.destination in self.destination_chains:
            return self.destination_chains[route.destination]
        if self.silent_errors:
            return self.default_chain
        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain synchronously.

        Returns:
            The output dict of the selected chain.

        Raises:
            ValueError: If the router names an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = self.router_chain.route(inputs, callbacks=callbacks)
        # Report the routing decision before running (or failing) the target.
        _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route)
        return chain(route.next_inputs, callbacks=callbacks)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route ``inputs`` and run the selected chain asynchronously.

        Returns:
            The output dict of the selected chain.

        Raises:
            ValueError: If the router names an unknown destination and
                ``silent_errors`` is False.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        route = await self.router_chain.aroute(inputs, callbacks=callbacks)
        # Report the routing decision before running (or failing) the target.
        await _run_manager.on_text(
            str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
        )
        chain = self._select_chain(route)
        return await chain.acall(route.next_inputs, callbacks=callbacks)