from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
Runnable,
RunnableLike,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
def _accumulate_chunk(
    accumulated: Optional[Any], chunk: Any, supported: bool
) -> Tuple[Optional[Any], bool]:
    """Fold one streamed chunk into the running aggregate.

    Args:
        accumulated: The aggregate built so far, or None when nothing has been
            accumulated yet.
        chunk: The chunk that was just yielded to the consumer.
        supported: False once an earlier ``+`` raised TypeError; the chunk is
            then ignored and the aggregate is left as it is.

    Returns:
        The updated ``(accumulated, supported)`` pair. If ``accumulated + chunk``
        raises TypeError the pair becomes ``(None, False)``.
    """
    if not supported:
        return accumulated, False
    if accumulated is None:
        # First chunk (or a previous chunk that was itself None): start over.
        return chunk, True
    try:
        return accumulated + chunk, True
    except TypeError:
        # Chunks of this type cannot be concatenated; stop aggregating.
        return None, False


class RunnableBranch(RunnableSerializable[Input, Output]):
    """Runnable that selects which branch to run based on a condition.

    The Runnable is initialized with a sequence of (condition, Runnable) pairs
    and a default branch.

    When operating on an input, the first condition that evaluates to True is
    selected, and the corresponding Runnable is run on the input.

    If no condition evaluates to True, the default branch is run on the input.

    Examples:

        .. code-block:: python

            from langchain_core.runnables import RunnableBranch

            branch = RunnableBranch(
                (lambda x: isinstance(x, str), lambda x: x.upper()),
                (lambda x: isinstance(x, int), lambda x: x + 1),
                (lambda x: isinstance(x, float), lambda x: x * 2),
                lambda x: "goodbye",
            )

            branch.invoke("hello") # "HELLO"
            branch.invoke(None) # "goodbye"
    """

    # Ordered (condition, runnable) pairs; the first truthy condition wins.
    branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
    # Runnable used when no condition evaluates to a truthy value.
    default: Runnable[Input, Output]

    def __init__(
        self,
        *branches: Union[
            Tuple[
                Union[
                    Runnable[Input, bool],
                    Callable[[Input], bool],
                    Callable[[Input], Awaitable[bool]],
                ],
                RunnableLike,
            ],
            RunnableLike,  # To accommodate the default branch
        ],
    ) -> None:
        """Build a branch from (condition, runnable) pairs plus a default.

        Args:
            *branches: One or more ``(condition, runnable)`` tuples or lists,
                followed by the default branch as the last positional
                argument. Every element is passed through
                ``coerce_to_runnable``.

        Raises:
            ValueError: If fewer than two positional arguments are given, or
                if a non-default branch does not have exactly two elements.
            TypeError: If the default is not a Runnable, callable or mapping,
                or if a non-default branch is not a tuple or list.
        """
        if len(branches) < 2:
            raise ValueError("RunnableBranch requires at least two branches")

        default = branches[-1]

        if not isinstance(
            default,
            (Runnable, Callable, Mapping),  # type: ignore[arg-type]
        ):
            raise TypeError(
                "RunnableBranch default must be runnable, callable or mapping."
            )

        default_ = cast(
            Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
        )

        _branches = []

        for branch in branches[:-1]:
            if not isinstance(branch, (tuple, list)):  # type: ignore[arg-type]
                raise TypeError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists, not {type(branch)}"
                )

            if len(branch) != 2:
                raise ValueError(
                    f"RunnableBranch branches must be "
                    f"tuples or lists of length 2, not {len(branch)}"
                )
            condition, runnable = branch
            condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
            runnable = coerce_to_runnable(runnable)
            _branches.append((condition, runnable))

        super().__init__(branches=_branches, default=default_)  # type: ignore[call-arg]

    class Config:
        # The fields hold Runnable instances rather than plain data types.
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report the class as serializable; this always returns True."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """Collect the unique config specs of the default and all branches.

        Specs are gathered from the default runnable, then from every branch
        runnable, then from every condition.

        Raises:
            ValueError: If any collected spec id starts with
                ``CONTEXT_CONFIG_PREFIX`` and ends with
                ``CONTEXT_CONFIG_SUFFIX_SET``.
        """
        from langchain_core.beta.runnables.context import (
            CONTEXT_CONFIG_PREFIX,
            CONTEXT_CONFIG_SUFFIX_SET,
        )

        specs = get_unique_config_specs(
            spec
            for step in (
                [self.default]
                + [r for _, r in self.branches]
                + [r for r, _ in self.branches]
            )
            for spec in step.config_specs
        )
        if any(
            s.id.startswith(CONTEXT_CONFIG_PREFIX)
            and s.id.endswith(CONTEXT_CONFIG_SUFFIX_SET)
            for s in specs
        ):
            raise ValueError("RunnableBranch cannot contain context setters.")
        return specs

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Evaluate the conditions in order and run the first matching branch.

        Conditions are invoked with child callbacks tagged ``condition:N``
        (1-based). The selected runnable is invoked with child callbacks
        tagged ``branch:N``, or ``branch:default`` when no condition is truthy.

        Args:
            input: Value passed to every condition and to the selected branch.
            config: Optional config; normalized with ``ensure_config``.
            **kwargs: Forwarded to the selected branch only, not to conditions.

        Returns:
            The output of the selected branch.

        Raises:
            BaseException: Anything raised by a condition or the selected
                branch is reported through ``on_chain_error`` and re-raised.
        """
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        try:
            # Fall back to the default unless a condition matches below.
            selected: Runnable[Input, Output] = self.default
            tag = "branch:default"
            for idx, (condition, runnable) in enumerate(self.branches, start=1):
                expression_value = condition.invoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx}"),
                    ),
                )
                if expression_value:
                    selected, tag = runnable, f"branch:{idx}"
                    break

            output = selected.invoke(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            )
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        # NOTE(review): invoke/ainvoke report dumpd(output) here while
        # stream/astream report the raw aggregate - confirm this is intended.
        run_manager.on_chain_end(dumpd(output))
        return output

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async version of ``invoke``.

        Conditions and the selected branch are awaited through ``ainvoke``;
        callback tags, argument forwarding and error reporting mirror
        ``invoke``.

        Args:
            input: Value passed to every condition and to the selected branch.
            config: Optional config; normalized with ``ensure_config``.
            **kwargs: Forwarded to the selected branch only, not to conditions.

        Returns:
            The output of the selected branch.

        Raises:
            BaseException: Anything raised by a condition or the selected
                branch is reported through ``on_chain_error`` and re-raised.
        """
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        try:
            # Fall back to the default unless a condition matches below.
            selected: Runnable[Input, Output] = self.default
            tag = "branch:default"
            for idx, (condition, runnable) in enumerate(self.branches, start=1):
                expression_value = await condition.ainvoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx}"),
                    ),
                )
                if expression_value:
                    selected, tag = runnable, f"branch:{idx}"
                    break

            output = await selected.ainvoke(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            )
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        # NOTE(review): invoke/ainvoke report dumpd(output) here while
        # stream/astream report the raw aggregate - confirm this is intended.
        await run_manager.on_chain_end(dumpd(output))
        return output

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Evaluate the conditions in order and stream the matching branch.

        Conditions are evaluated with ``invoke``; the selected branch (or the
        default when no condition is truthy) is consumed with ``stream`` and
        each chunk is yielded as it arrives. Chunks are also combined with
        ``+`` so the aggregate can be passed to ``on_chain_end``; if the
        chunks cannot be added, None is reported instead.

        Args:
            input: Value passed to every condition and to the selected branch.
            config: Optional config; normalized with ``ensure_config``.
            **kwargs: Forwarded to the selected branch only, not to conditions.

        Yields:
            The chunks produced by the selected branch.

        Raises:
            BaseException: Anything raised while evaluating conditions or
                streaming is reported through ``on_chain_error`` and re-raised.
        """
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        final_output: Optional[Output] = None
        final_output_supported = True

        try:
            # Fall back to the default unless a condition matches below.
            selected: Runnable[Input, Output] = self.default
            tag = "branch:default"
            for idx, (condition, runnable) in enumerate(self.branches, start=1):
                expression_value = condition.invoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx}"),
                    ),
                )
                if expression_value:
                    selected, tag = runnable, f"branch:{idx}"
                    break

            for chunk in selected.stream(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            ):
                yield chunk
                final_output, final_output_supported = _accumulate_chunk(
                    final_output, chunk, final_output_supported
                )
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(final_output)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Async version of ``stream``.

        Conditions are evaluated with ``ainvoke``; the selected branch (or the
        default when no condition is truthy) is consumed with ``astream`` and
        each chunk is yielded as it arrives. Chunks are also combined with
        ``+`` so the aggregate can be passed to ``on_chain_end``; if the
        chunks cannot be added, None is reported instead.

        Args:
            input: Value passed to every condition and to the selected branch.
            config: Optional config; normalized with ``ensure_config``.
            **kwargs: Forwarded to the selected branch only, not to conditions.

        Yields:
            The chunks produced by the selected branch.

        Raises:
            BaseException: Anything raised while evaluating conditions or
                streaming is reported through ``on_chain_error`` and re-raised.
        """
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        final_output: Optional[Output] = None
        final_output_supported = True

        try:
            # Fall back to the default unless a condition matches below.
            selected: Runnable[Input, Output] = self.default
            tag = "branch:default"
            for idx, (condition, runnable) in enumerate(self.branches, start=1):
                expression_value = await condition.ainvoke(
                    input,
                    config=patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"condition:{idx}"),
                    ),
                )
                if expression_value:
                    selected, tag = runnable, f"branch:{idx}"
                    break

            async for chunk in selected.astream(
                input,
                config=patch_config(config, callbacks=run_manager.get_child(tag=tag)),
                **kwargs,
            ):
                yield chunk
                final_output, final_output_supported = _accumulate_chunk(
                    final_output, chunk, final_output_supported
                )
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(final_output)