from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from tenacity import (
AsyncRetrying,
RetryCallState,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain_core.runnables.base import Input, Output, RunnableBindingBase
from langchain_core.runnables.config import RunnableConfig, patch_config
if TYPE_CHECKING:
    # Imported for static type checking only; these names are not bound at
    # runtime, so annotations that use them are written as strings.
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForChainRun,
        CallbackManagerForChainRun,
    )

    # Type variable constrained to the sync / async chain run managers.
    # It must stay under the guard because its constraints are only
    # available to type checkers; annotations refer to it as "T".
    T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)

# Unconstrained type variable for generic list helpers.
U = TypeVar("U")
class RunnableRetry(RunnableBindingBase[Input, Output]):
    """Retry a Runnable if it fails.

    RunnableRetry can be used to add retry logic to any object that
    subclasses the base Runnable.

    Such retries are especially useful for network calls that may fail due
    to transient errors.

    RunnableRetry is implemented as a RunnableBinding. The easiest way to
    use it is through the `.with_retry()` method available on all Runnables.

    Example:

    Here is an example that uses a RunnableLambda to raise an exception

        .. code-block:: python

            import time

            def foo(input) -> None:
                '''Fake function that raises an exception.'''
                raise ValueError(f"Invoking foo failed. At time {time.time()}")

            runnable = RunnableLambda(foo)

            runnable_with_retries = runnable.with_retry(
                retry_if_exception_type=(ValueError,),  # Retry only on ValueError
                wait_exponential_jitter=True,  # Add jitter to the backoff
                stop_after_attempt=2,  # Try twice
            )

            # The method call above is equivalent to the longer form below:

            runnable_with_retries = RunnableRetry(
                bound=runnable,
                retry_exception_types=(ValueError,),
                max_attempt_number=2,
                wait_exponential_jitter=True
            )

    This logic can be used to retry any Runnable, including a chain of
    Runnables, but in general it is best to keep the scope of the retry as
    small as possible. For example, if you have a chain of Runnables, you
    should only retry the Runnable that is likely to fail, not the entire
    chain.

    Example:

        .. code-block:: python

            from langchain_core.chat_models import ChatOpenAI
            from langchain_core.prompts import PromptTemplate

            template = PromptTemplate.from_template("tell me a joke about {topic}.")
            model = ChatOpenAI(temperature=0.5)

            # Good
            chain = template | model.with_retry()

            # Bad
            chain = template | model
            retryable_chain = chain.with_retry()
    """

    retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
    """The exception types to retry on. By default all exceptions are retried.

    In general you should only retry on exceptions that are likely to be
    transient, such as network errors.

    Good exceptions to retry are all server errors (5xx) and selected client
    errors (4xx) such as 429 Too Many Requests.
    """

    wait_exponential_jitter: bool = True
    """Whether to add jitter to the exponential backoff."""

    max_attempt_number: int = 3
    """The maximum number of attempts to retry the runnable."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def _kwargs_retrying(self) -> Dict[str, Any]:
        """Build the tenacity keyword arguments from this object's settings.

        Each of ``stop``, ``wait`` and ``retry`` is only included when the
        corresponding field is truthy.
        """
        kwargs: Dict[str, Any] = {}

        if self.max_attempt_number:
            kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

        if self.wait_exponential_jitter:
            kwargs["wait"] = wait_exponential_jitter()

        if self.retry_exception_types:
            kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

        return kwargs

    def _sync_retrying(self, **kwargs: Any) -> Retrying:
        """Create a synchronous tenacity controller; ``kwargs`` are passed on."""
        return Retrying(**self._kwargs_retrying, **kwargs)

    def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
        """Create an asynchronous tenacity controller; ``kwargs`` are passed on."""
        return AsyncRetrying(**self._kwargs_retrying, **kwargs)

    def _patch_config(
        self,
        config: RunnableConfig,
        run_manager: "T",
        retry_state: RetryCallState,
    ) -> RunnableConfig:
        """Return ``config`` with child callbacks for the current attempt.

        From the second attempt onward the child callbacks are tagged with
        ``retry:attempt:<n>``; the first attempt is left untagged.
        """
        attempt = retry_state.attempt_number
        tag = f"retry:attempt:{attempt}" if attempt > 1 else None
        return patch_config(config, callbacks=run_manager.get_child(tag))

    def _patch_config_list(
        self,
        config: List[RunnableConfig],
        run_manager: List["T"],
        retry_state: RetryCallState,
    ) -> List[RunnableConfig]:
        """Apply ``_patch_config`` to each (config, run manager) pair."""
        return [
            self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
        ]

    def _record_batch_results(
        self,
        indices: List[int],
        result: List[Union[Output, Exception]],
        results_map: Dict[int, Output],
        last_errors: Dict[int, Exception],
    ) -> Optional[Exception]:
        """Record one batch attempt's results under their original indices.

        Args:
            indices: Original input index of each entry in ``result``.
            result: Outputs or exceptions returned by the attempt.
            results_map: Updated in place with the successful outputs.
            last_errors: Updated in place with the latest exception seen
                for each input that failed in this attempt.

        Returns:
            The first exception found in ``result``, or None if every
            entry succeeded.
        """
        first_exception: Optional[Exception] = None
        for idx, item in zip(indices, result):
            if isinstance(item, Exception):
                last_errors[idx] = item
                if first_exception is None:
                    first_exception = item
            else:
                results_map[idx] = item
        return first_exception

    def _merge_batch_outputs(
        self,
        size: int,
        results_map: Dict[int, Output],
        last_errors: Dict[int, Exception],
    ) -> List[Union[Output, Exception]]:
        """Build the final output list in input order.

        A successful output takes precedence over any exception recorded
        for the same index in an earlier attempt.
        """
        return [
            results_map[idx] if idx in results_map else last_errors[idx]
            for idx in range(size)
        ]

    def _invoke(
        self,
        input: Input,
        run_manager: "CallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        """Invoke the bound runnable, retrying on failure.

        ``reraise=True`` is passed to the retry controller, so once the
        attempts are exhausted the underlying exception is raised rather
        than a ``RetryError``.
        """
        for attempt in self._sync_retrying(reraise=True):
            with attempt:
                result = super().invoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            # Store the actual return value on the retry state once the
            # attempt has been recorded as successful.
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the bound runnable with retries."""
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Input,
        run_manager: "AsyncCallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        """Asynchronous counterpart of ``_invoke``."""
        async for attempt in self._async_retrying(reraise=True):
            with attempt:
                result = await super().ainvoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            # Store the actual return value on the retry state once the
            # attempt has been recorded as successful.
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Asynchronously invoke the bound runnable with retries."""
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _batch(
        self,
        inputs: List[Input],
        run_manager: List["CallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        """Run the bound runnable on a batch, retrying only the failed inputs.

        Each attempt re-submits just the inputs that have not succeeded yet.
        Results are tracked by original input index so that outputs and
        exceptions stay aligned with ``inputs`` across attempts.

        Returns:
            One entry per input, in input order: the output if the input
            eventually succeeded, otherwise the last exception recorded for
            it (or the ``RetryError`` when no per-input exception exists).
        """
        results_map: Dict[int, Output] = {}
        # Latest exception per original index for inputs still failing.
        last_errors: Dict[int, Exception] = {}
        try:
            for attempt in self._sync_retrying():
                with attempt:
                    # Original indices of the inputs that have not succeeded yet.
                    remaining = [i for i in range(len(inputs)) if i not in results_map]
                    result = super().batch(
                        [inputs[i] for i in remaining],
                        self._patch_config_list(
                            [config[i] for i in remaining],
                            [run_manager[i] for i in remaining],
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    first_exception = self._record_batch_results(
                        remaining, result, results_map, last_errors
                    )
                    # If any exception occurred, raise it to retry the failed ones.
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            # Attempts exhausted: inputs with no per-input exception (the
            # underlying batch call itself raised) report the RetryError.
            for idx in range(len(inputs)):
                if idx not in results_map:
                    last_errors.setdefault(idx, e)
        return self._merge_batch_outputs(len(inputs), results_map, last_errors)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        """Run the bound runnable on a batch of inputs with retries."""
        return self._batch_with_config(
            self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    async def _abatch(
        self,
        inputs: List[Input],
        run_manager: List["AsyncCallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        """Asynchronous counterpart of ``_batch``.

        Each attempt re-submits just the inputs that have not succeeded yet.
        Results are tracked by original input index so that outputs and
        exceptions stay aligned with ``inputs`` across attempts.

        Returns:
            One entry per input, in input order: the output if the input
            eventually succeeded, otherwise the last exception recorded for
            it (or the ``RetryError`` when no per-input exception exists).
        """
        results_map: Dict[int, Output] = {}
        # Latest exception per original index for inputs still failing.
        last_errors: Dict[int, Exception] = {}
        try:
            async for attempt in self._async_retrying():
                with attempt:
                    # Original indices of the inputs that have not succeeded yet.
                    remaining = [i for i in range(len(inputs)) if i not in results_map]
                    result = await super().abatch(
                        [inputs[i] for i in remaining],
                        self._patch_config_list(
                            [config[i] for i in remaining],
                            [run_manager[i] for i in remaining],
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    first_exception = self._record_batch_results(
                        remaining, result, results_map, last_errors
                    )
                    # If any exception occurred, raise it to retry the failed ones.
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            # Attempts exhausted: inputs with no per-input exception (the
            # underlying batch call itself raised) report the RetryError.
            for idx in range(len(inputs)):
                if idx not in results_map:
                    last_errors.setdefault(idx, e)
        return self._merge_batch_outputs(len(inputs), results_map, last_errors)

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        """Asynchronously run the bound runnable on a batch with retries."""
        return await self._abatch_with_config(
            self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    # stream() and transform() are not retried because retrying a stream
    # is not very intuitive.