import asyncio
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableSerializable
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
get_config_list,
patch_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Input,
Output,
get_unique_config_specs,
)
from langchain_core.utils.aiter import py_anext
if TYPE_CHECKING:
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
    """A Runnable that can fall back to other Runnables if it fails.

    External APIs (e.g., APIs for a language model) may at times experience
    degraded performance or even downtime.

    In these cases, it can be useful to have a fallback Runnable that can be
    used in place of the original Runnable (e.g., fall back to another LLM
    provider).

    Fallbacks can be defined at the level of a single Runnable, or at the level
    of a chain of Runnables. Fallbacks are tried in order until one succeeds or
    all fail.

    While you can instantiate a ``RunnableWithFallbacks`` directly, it is
    usually more convenient to use the ``with_fallbacks`` method on a Runnable.

    Example:

        .. code-block:: python

            from langchain_core.chat_models.openai import ChatOpenAI
            from langchain_core.chat_models.anthropic import ChatAnthropic

            model = ChatAnthropic(
                model="claude-3-haiku-20240307"
            ).with_fallbacks([ChatOpenAI(model="gpt-3.5-turbo-0125")])
            # Will usually use ChatAnthropic, but fall back to ChatOpenAI
            # if ChatAnthropic fails.
            model.invoke('hello')

            # Fallbacks can also be used at the level of a chain.
            # Here, if both LLM providers fail, we fall back to a good
            # hardcoded response.

            from langchain_core.prompts import PromptTemplate
            from langchain_core.output_parser import StrOutputParser
            from langchain_core.runnables import RunnableLambda

            def when_all_is_lost(inputs):
                return ("Looks like our LLM providers are down. "
                        "Here's a nice 🦜️ emoji for you instead.")

            chain_with_fallback = (
                PromptTemplate.from_template('Tell me a joke about {topic}')
                | model
                | StrOutputParser()
            ).with_fallbacks([RunnableLambda(when_all_is_lost)])
    """

    runnable: Runnable[Input, Output]
    """The Runnable to run first."""
    fallbacks: Sequence[Runnable[Input, Output]]
    """A sequence of fallbacks to try, in order."""
    exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,)
    """The exceptions on which fallbacks should be tried.

    Any exception that is not a subclass of these exceptions will be raised
    immediately.
    """
    exception_key: Optional[str] = None
    """If a string is specified, handled exceptions are passed to the fallbacks
    as part of the input under this key. If None, exceptions are not passed to
    fallbacks. If used, the base Runnable and its fallbacks must accept a
    dictionary as input.
    """

    class Config:
        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[Input]:
        """Input type, delegated to the primary runnable."""
        return self.runnable.InputType

    @property
    def OutputType(self) -> Type[Output]:
        """Output type, delegated to the primary runnable."""
        return self.runnable.OutputType

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return the output schema of the primary runnable."""
        return self.runnable.get_output_schema(config)

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """De-duplicated config specs of the primary runnable and all fallbacks."""
        return get_unique_config_specs(
            spec
            for step in [self.runnable, *self.fallbacks]
            for spec in step.config_specs
        )

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def runnables(self) -> Iterator[Runnable[Input, Output]]:
        """Yield the primary runnable followed by each fallback, in order."""
        yield self.runnable
        yield from self.fallbacks

    def _check_dict_input(self, input: Any) -> None:
        """Raise ValueError if ``exception_key`` is set and ``input`` is not a dict."""
        if self.exception_key is not None and not isinstance(input, dict):
            raise ValueError(
                "If 'exception_key' is specified then input must be a dictionary. "
                f"However found a type of {type(input)} for input"
            )

    def _check_dict_inputs(self, inputs: Sequence[Any]) -> None:
        """Raise ValueError if ``exception_key`` is set and any input is not a dict.

        The message reports the type of the first offending element.
        """
        if self.exception_key is None:
            return
        for candidate in inputs:
            if not isinstance(candidate, dict):
                raise ValueError(
                    "If 'exception_key' is specified then inputs must be "
                    "dictionaries. "
                    f"However found a type of {type(candidate)} for input"
                )

    def _input_with_exception(
        self, input: Input, error: Optional[BaseException]
    ) -> Input:
        """Return the input a fallback should receive after ``error``.

        When ``exception_key`` is set and an error occurred, the error is
        placed under that key in a shallow copy, so the caller's dictionary
        is never mutated. Otherwise ``input`` is returned unchanged.
        """
        if self.exception_key and error is not None:
            merged = {**cast(Dict[str, Any], input), self.exception_key: error}
            return cast(Input, merged)
        return input

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the primary runnable, falling back on handled exceptions.

        Args:
            input: Input for the runnables. Must be a dict if
                ``exception_key`` is set.
            config: Optional runnable config.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Returns:
            The output of the first runnable that succeeds.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                immediately any error not in ``exceptions_to_handle``.
        """
        self._check_dict_input(input)
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                output = runnable.invoke(
                    self._input_with_exception(input, last_error),
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        run_manager.on_chain_error(first_error)
        raise first_error

    async def ainvoke(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Async counterpart of ``invoke``.

        Args:
            input: Input for the runnables. Must be a dict if
                ``exception_key`` is set.
            config: Optional runnable config.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Returns:
            The output of the first runnable that succeeds.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails, or
                immediately any error not in ``exceptions_to_handle``.
        """
        self._check_dict_input(input)
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                output = await runnable.ainvoke(
                    self._input_with_exception(input, last_error),
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                await run_manager.on_chain_end(output)
                return output
        if first_error is None:
            raise ValueError("No error stored at end of fallbacks.")
        await run_manager.on_chain_error(first_error)
        raise first_error

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Run a batch, retrying only the failed inputs on each fallback.

        Args:
            inputs: Inputs to process. Each must be a dict if
                ``exception_key`` is set.
            config: A single config or one config per input.
            return_exceptions: If True, exceptions are returned in place of
                outputs instead of being raised.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Returns:
            Outputs (or exceptions, if ``return_exceptions``) in input order.

        Raises:
            ValueError: If ``exception_key`` is set and an input is not a dict.
            BaseException: When ``return_exceptions`` is False, either the
                first error not in ``exceptions_to_handle``, or the recorded
                error of the lowest-index input for which all runnables failed.
        """
        from langchain_core.callbacks.manager import CallbackManager

        self._check_dict_inputs(inputs)
        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            CallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        run_managers = [
            cm.on_chain_start(
                dumpd(self),
                input if isinstance(input, dict) else {"input": input},
                name=config.get("run_name"),
                run_id=config.pop("run_id", None),
            )
            for cm, input, config in zip(callback_managers, inputs, configs)
        ]

        to_return: Dict[int, Any] = {}
        # Inputs whose root run is still open, keyed by original position.
        run_again: Dict[int, Any] = dict(enumerate(inputs))
        # Latest exception recorded for each input that has not succeeded.
        errors: Dict[int, BaseException] = {}
        first_to_raise: Optional[BaseException] = None
        for runnable in self.runnables:
            pending = sorted(run_again)
            # Collect exceptions per input so one failure does not abort the
            # whole batch before we can decide whether to fall back.
            outputs = runnable.batch(
                [run_again[i] for i in pending],
                [
                    # each step a child run of the corresponding root run
                    patch_config(configs[i], callbacks=run_managers[i].get_child())
                    for i in pending
                ],
                return_exceptions=True,
                **kwargs,
            )
            for i, output in zip(pending, outputs):
                if isinstance(output, self.exceptions_to_handle):
                    # Handled failure: keep the input pending for the next
                    # fallback, passing the exception along if requested.
                    run_again[i] = self._input_with_exception(
                        run_again[i], cast(BaseException, output)
                    )
                    errors[i] = cast(BaseException, output)
                elif isinstance(output, BaseException):
                    # Unhandled failure: never retried.
                    errors[i] = output
                    if return_exceptions:
                        run_again.pop(i)
                    elif first_to_raise is None:
                        first_to_raise = output
                else:
                    run_managers[i].on_chain_end(output)
                    to_return[i] = output
                    run_again.pop(i)
                    errors.pop(i, None)
            if first_to_raise is not None:
                # Close every still-open root run before aborting so that
                # callbacks are not left with unfinished runs.
                for i in sorted(run_again):
                    run_managers[i].on_chain_error(errors.get(i, first_to_raise))
                raise first_to_raise
            if not run_again:
                break

        sorted_errors = sorted(errors.items())
        for i, error in sorted_errors:
            run_managers[i].on_chain_error(error)
        if not return_exceptions and sorted_errors:
            raise sorted_errors[0][1]
        to_return.update(errors)
        return [output for _, output in sorted(to_return.items())]

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Async counterpart of ``batch``.

        Args:
            inputs: Inputs to process. Each must be a dict if
                ``exception_key`` is set.
            config: A single config or one config per input.
            return_exceptions: If True, exceptions are returned in place of
                outputs instead of being raised.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Returns:
            Outputs (or exceptions, if ``return_exceptions``) in input order.

        Raises:
            ValueError: If ``exception_key`` is set and an input is not a dict.
            BaseException: When ``return_exceptions`` is False, either the
                first error not in ``exceptions_to_handle``, or the recorded
                error of the lowest-index input for which all runnables failed.
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        self._check_dict_inputs(inputs)
        if not inputs:
            return []

        # setup callbacks
        configs = get_config_list(config, len(inputs))
        callback_managers = [
            AsyncCallbackManager.configure(
                inheritable_callbacks=config.get("callbacks"),
                local_callbacks=None,
                verbose=False,
                inheritable_tags=config.get("tags"),
                local_tags=None,
                inheritable_metadata=config.get("metadata"),
                local_metadata=None,
            )
            for config in configs
        ]
        # start the root runs, one per input
        run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
            *(
                cm.on_chain_start(
                    dumpd(self),
                    input,
                    name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                )
                for cm, input, config in zip(callback_managers, inputs, configs)
            )
        )

        to_return: Dict[int, Any] = {}
        # Inputs whose root run is still open, keyed by original position.
        run_again: Dict[int, Any] = dict(enumerate(inputs))
        # Latest exception recorded for each input that has not succeeded.
        errors: Dict[int, BaseException] = {}
        first_to_raise: Optional[BaseException] = None
        for runnable in self.runnables:
            pending = sorted(run_again)
            # Collect exceptions per input so one failure does not abort the
            # whole batch before we can decide whether to fall back.
            outputs = await runnable.abatch(
                [run_again[i] for i in pending],
                [
                    # each step a child run of the corresponding root run
                    patch_config(configs[i], callbacks=run_managers[i].get_child())
                    for i in pending
                ],
                return_exceptions=True,
                **kwargs,
            )
            for i, output in zip(pending, outputs):
                if isinstance(output, self.exceptions_to_handle):
                    # Handled failure: keep the input pending for the next
                    # fallback, passing the exception along if requested.
                    run_again[i] = self._input_with_exception(
                        run_again[i], cast(BaseException, output)
                    )
                    errors[i] = cast(BaseException, output)
                elif isinstance(output, BaseException):
                    # Unhandled failure: never retried.
                    errors[i] = output
                    if return_exceptions:
                        run_again.pop(i)
                    elif first_to_raise is None:
                        first_to_raise = output
                else:
                    to_return[i] = output
                    await run_managers[i].on_chain_end(output)
                    run_again.pop(i)
                    errors.pop(i, None)
            if first_to_raise is not None:
                # Close every still-open root run before aborting so that
                # callbacks are not left with unfinished runs.
                await asyncio.gather(
                    *(
                        run_managers[i].on_chain_error(errors.get(i, first_to_raise))
                        for i in sorted(run_again)
                    )
                )
                raise first_to_raise
            if not run_again:
                break

        sorted_errors = sorted(errors.items())
        await asyncio.gather(
            *(run_managers[i].on_chain_error(error) for i, error in sorted_errors)
        )
        if not return_exceptions and sorted_errors:
            raise sorted_errors[0][1]
        to_return.update(errors)
        return [output for _, output in sorted(to_return.items())]  # type: ignore

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Stream from the first runnable that produces a first chunk.

        Fallbacks are only tried for handled errors raised before the first
        chunk arrives. Once a runnable has yielded a chunk, later errors from
        that stream are reported to the callbacks and re-raised.

        Args:
            input: Input for the runnables. Must be a dict if
                ``exception_key`` is set.
            config: Optional runnable config.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Yields:
            Output chunks of the selected runnable.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails
                before its first chunk, or any error not in
                ``exceptions_to_handle``.
        """
        self._check_dict_input(input)
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_callback_manager_for_config(config)
        # start the root run
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                stream = runnable.stream(
                    self._input_with_exception(input, last_error),
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
                # Pull the first chunk eagerly: a failure here still allows
                # falling back to the next runnable.
                chunk = next(stream)
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                run_manager.on_chain_error(e)
                raise
            else:
                # A runnable produced output; earlier failures are moot.
                first_error = None
                break
        if first_error is not None:
            run_manager.on_chain_error(first_error)
            raise first_error

        yield chunk
        # Accumulate chunks for the final callback; None if not addable.
        output: Optional[Output] = chunk
        try:
            for chunk in stream:
                yield chunk
                try:
                    output = output + chunk  # type: ignore
                except TypeError:
                    output = None
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        run_manager.on_chain_end(output)

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Async counterpart of ``stream``.

        Fallbacks are only tried for handled errors raised before the first
        chunk arrives. Once a runnable has yielded a chunk, later errors from
        that stream are reported to the callbacks and re-raised.

        Args:
            input: Input for the runnables. Must be a dict if
                ``exception_key`` is set.
            config: Optional runnable config.
            **kwargs: Extra keyword arguments forwarded to each runnable.

        Yields:
            Output chunks of the selected runnable.

        Raises:
            ValueError: If ``exception_key`` is set and ``input`` is not a dict.
            BaseException: The first handled error if every runnable fails
                before its first chunk, or any error not in
                ``exceptions_to_handle``.
        """
        self._check_dict_input(input)
        # setup callbacks
        config = ensure_config(config)
        callback_manager = get_async_callback_manager_for_config(config)
        # start the root run
        run_manager = await callback_manager.on_chain_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )
        first_error = None
        last_error = None
        for runnable in self.runnables:
            try:
                stream = runnable.astream(
                    self._input_with_exception(input, last_error),
                    patch_config(config, callbacks=run_manager.get_child()),
                    **kwargs,
                )
                # Pull the first chunk eagerly: a failure here still allows
                # falling back to the next runnable.
                chunk = await cast(Awaitable[Output], py_anext(stream))
            except self.exceptions_to_handle as e:
                if first_error is None:
                    first_error = e
                last_error = e
            except BaseException as e:
                await run_manager.on_chain_error(e)
                raise
            else:
                # A runnable produced output; earlier failures are moot.
                first_error = None
                break
        if first_error is not None:
            await run_manager.on_chain_error(first_error)
            raise first_error

        yield chunk
        # Accumulate chunks for the final callback; None if not addable.
        output: Optional[Output] = chunk
        try:
            async for chunk in stream:
                yield chunk
                try:
                    output = output + chunk  # type: ignore
                except TypeError:
                    output = None
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        await run_manager.on_chain_end(output)