# Source code for langchain_core.runnables.branch

from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
    Runnable,
    RunnableLike,
    RunnableSerializable,
    coerce_to_runnable,
)
from langchain_core.runnables.config import (
    RunnableConfig,
    ensure_config,
    get_async_callback_manager_for_config,
    get_callback_manager_for_config,
    patch_config,
)
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    Input,
    Output,
    get_unique_config_specs,
)


class RunnableBranch(RunnableSerializable[Input, Output]):
    """Runnable that selects which branch to run based on a condition.

    The Runnable is initialized with a list of (condition, Runnable) pairs and
    a default branch.

    When operating on an input, the conditions are evaluated in order; the
    first condition that evaluates to a truthy value is selected, and the
    corresponding Runnable is run on the input.

    If no condition evaluates to a truthy value, the default branch is run on
    the input.

    Examples:

        .. code-block:: python

            from langchain_core.runnables import RunnableBranch

            branch = RunnableBranch(
                (lambda x: isinstance(x, str), lambda x: x.upper()),
                (lambda x: isinstance(x, int), lambda x: x + 1),
                (lambda x: isinstance(x, float), lambda x: x * 2),
                lambda x: "goodbye",
            )

            branch.invoke("hello") # "HELLO"
            branch.invoke(None) # "goodbye"
    """

    # (condition, runnable) pairs, checked in order; the runnable paired with
    # the first truthy condition is the one that runs.
    branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
    # Runnable used when no condition is truthy.
    default: Runnable[Input, Output]

    def __init__(
        self,
        *branches: Union[
            Tuple[
                Union[
                    Runnable[Input, bool],
                    Callable[[Input], bool],
                    Callable[[Input], Awaitable[bool]],
                ],
                RunnableLike,
            ],
            RunnableLike,  # To accommodate the default branch
        ],
    ) -> None:
        """Create a branch from (condition, runnable) pairs plus a default.

        Args:
            *branches: One or more ``(condition, runnable)`` pairs (tuples or
                lists of length 2) followed by exactly one default branch.
                Conditions, runnables and the default are all coerced with
                ``coerce_to_runnable``.

        Raises:
            ValueError: If fewer than two arguments are given, or if a
                condition/runnable pair does not have exactly two elements.
            TypeError: If the default is not a Runnable, callable or mapping,
                or if a condition/runnable pair is not a tuple or list.
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch requires at least two branches")

        # The last positional argument is always the default branch.
        default = branches[-1]

        if not isinstance(
            default,
            (Runnable, Callable, Mapping),  # type: ignore[arg-type]
        ):
            raise TypeError(
                "RunnableBranch default must be runnable, callable or mapping."
            )

        default_ = cast(
            Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
        )

        _branches = []

        for branch in branches[:-1]:
            if not isinstance(branch, (tuple, list)):  # type: ignore[arg-type]
                raise TypeError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists, not {type(branch)}"
                )

            if len(branch) != 2:
                raise ValueError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists of length 2, not {len(branch)}"
                )
            condition, runnable = branch
            condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
            runnable = coerce_to_runnable(runnable)
            _branches.append((condition, runnable))

        super().__init__(branches=_branches, default=default_)  # type: ignore[call-arg]

    class Config:
        # Fields hold arbitrary Runnable instances, which pydantic cannot
        # validate natively.
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether RunnableBranch is LangChain-serializable.

        Always returns True; this method does not inspect the individual
        branches.
        """
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return the input schema of the first sub-runnable that declares a type.

        Candidates are checked in this order: the default branch, then each
        branch runnable, then each condition. A candidate is chosen when its
        input schema's JSON schema contains a ``"type"`` key. If no candidate
        qualifies, the base-class input schema is returned.
        """
        runnables = (
            [self.default]
            + [r for _, r in self.branches]
            + [r for r, _ in self.branches]
        )

        for runnable in runnables:
            # Build the schema once per candidate rather than twice.
            input_schema = runnable.get_input_schema(config)
            if input_schema.schema().get("type") is not None:
                return input_schema

        return super().get_input_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """De-duplicated configurable field specs of the default, branches and conditions.

        Raises:
            ValueError: If any spec id starts with ``CONTEXT_CONFIG_PREFIX``
                and ends with ``CONTEXT_CONFIG_SUFFIX_SET`` (a context setter),
                which is not allowed inside a RunnableBranch.
        """
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the beta context module — confirm.
        from langchain_core.beta.runnables.context import (
            CONTEXT_CONFIG_PREFIX,
            CONTEXT_CONFIG_SUFFIX_SET,
        )

        specs = get_unique_config_specs(
            spec
            for step in (
                [self.default]
                + [r for _, r in self.branches]
                + [r for r, _ in self.branches]
            )
            for spec in step.config_specs
        )
        if any(
            s.id.startswith(CONTEXT_CONFIG_PREFIX)
            and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
            for s in specs
        ):
            raise ValueError("RunnableBranch cannot contain context setters.")
        return specs

    def _select_branch(
        self, input: Input, config: RunnableConfig, run_manager: Any
    ) -> Tuple[Runnable[Input, Output], str]:
        """Evaluate conditions in order and pick the runnable to run.

        Each condition is invoked with a child callback manager tagged
        ``"condition:<n>"`` (1-based). Evaluation stops at the first truthy
        condition.

        Returns:
            ``(runnable, "branch:<n>")`` for the first truthy condition, or
            ``(self.default, "branch:default")`` if none is truthy.
        """
        for idx, (condition, runnable) in enumerate(self.branches):
            expression_value = condition.invoke(
                input,
                config=patch_config(
                    config,
                    callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                ),
            )
            if expression_value:
                return runnable, f"branch:{idx + 1}"
        return self.default, "branch:default"

    async def _aselect_branch(
        self, input: Input, config: RunnableConfig, run_manager: Any
    ) -> Tuple[Runnable[Input, Output], str]:
        """Async counterpart of ``_select_branch`` (conditions use ``ainvoke``)."""
        for idx, (condition, runnable) in enumerate(self.branches):
            expression_value = await condition.ainvoke(
                input,
                config=patch_config(
                    config,
                    callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"),
                ),
            )
            if expression_value:
                return runnable, f"branch:{idx + 1}"
        return self.default, "branch:default"

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Evaluate the conditions, then delegate to the selected branch.

        Args:
            input: Passed to every evaluated condition and to the chosen branch.
            config: Optional run config; ``run_id`` is popped from the ensured
                copy and used for the chain-start callback.
            **kwargs: Forwarded to the chosen branch's ``invoke`` (not to the
                conditions).

        Returns:
            The output of the chosen branch.

        Raises:
            Any exception raised by a condition or the chosen branch; it is
            reported via ``on_chain_error`` and then re-raised.
        """
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        try:
            runnable, tag = self._select_branch(input, config, run_manager)
            output = runnable.invoke(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            )
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(dumpd(output))
        return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async version of ``invoke``."""
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        try:
            runnable, tag = await self._aselect_branch(input, config, run_manager)
            output = await runnable.ainvoke(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            )
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(dumpd(output))
        return output

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Evaluate the conditions, then stream from the selected branch.

        Chunks are yielded as they arrive. They are also summed with ``+`` to
        build the value passed to ``on_chain_end``. If two chunks cannot be
        added (``TypeError``), accumulation stops and ``None`` is reported
        instead.
        """
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        final_output: Optional[Output] = None
        final_output_supported = True

        try:
            runnable, tag = self._select_branch(input, config, run_manager)
            for chunk in runnable.stream(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            ):
                yield chunk
                if final_output_supported:
                    if final_output is None:
                        final_output = chunk
                    else:
                        try:
                            final_output = final_output + chunk  # type: ignore
                        except TypeError:
                            # Chunks are not addable: give up on aggregating.
                            final_output = None
                            final_output_supported = False
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(final_output)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Async version of ``stream``: evaluate conditions, then stream the branch."""
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        final_output: Optional[Output] = None
        final_output_supported = True

        try:
            runnable, tag = await self._aselect_branch(input, config, run_manager)
            async for chunk in runnable.astream(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            ):
                yield chunk
                if final_output_supported:
                    if final_output is None:
                        final_output = chunk
                    else:
                        try:
                            final_output = final_output + chunk  # type: ignore
                        except TypeError:
                            # Chunks are not addable: give up on aggregating.
                            final_output = None
                            final_output_supported = False
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(final_output)