Source code for langchain_core.runnables.retry

from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from tenacity import (
    AsyncRetrying,
    RetryCallState,
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)

from langchain_core.runnables.base import Input, Output, RunnableBindingBase
from langchain_core.runnables.config import RunnableConfig, patch_config

if TYPE_CHECKING:
    # The callback manager types are imported only for static type checking;
    # code in this module refers to them through string annotations.
    from langchain_core.callbacks.manager import (
        AsyncCallbackManagerForChainRun,
        CallbackManagerForChainRun,
    )

    # Constrained to either the sync or the async chain-run callback manager.
    # Because it is defined only under TYPE_CHECKING, it does not exist at
    # runtime and must be referenced as the string annotation "T".
    T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
# Unconstrained generic element type; unlike ``T`` it is defined at runtime,
# outside the TYPE_CHECKING guard, for generic list helpers.
U = TypeVar("U")


class RunnableRetry(RunnableBindingBase[Input, Output]):
    """Retry a Runnable if it fails.

    RunnableRetry can be used to add retry logic to any object that subclasses
    the base Runnable.

    Such retries are especially useful for network calls that may fail due to
    transient errors.

    RunnableRetry is implemented as a RunnableBinding. The easiest way to use it
    is through the `.with_retry()` method available on all Runnables.

    Example:

    Here's an example that uses a RunnableLambda to raise an exception

        .. code-block:: python

            import time

            def foo(input) -> None:
                '''Fake function that raises an exception.'''
                raise ValueError(f"Invoking foo failed. At time {time.time()}")

            runnable = RunnableLambda(foo)

            runnable_with_retries = runnable.with_retry(
                retry_if_exception_type=(ValueError,),  # Retry only on ValueError
                wait_exponential_jitter=True,  # Add jitter to the exponential backoff
                stop_after_attempt=2,  # Try twice
            )

            # The method invocation above is equivalent to the longer form below:

            runnable_with_retries = RunnableRetry(
                bound=runnable,
                retry_exception_types=(ValueError,),
                max_attempt_number=2,
                wait_exponential_jitter=True
            )

    This logic can be used to retry any Runnable, including a chain of Runnables,
    but in general it's best practice to keep the scope of the retry as small as
    possible. For example, if you have a chain of Runnables, you should only retry
    the Runnable that is likely to fail, not the entire chain.

    Example:

        .. code-block:: python

            from langchain_core.prompts import PromptTemplate
            from langchain_openai import ChatOpenAI

            template = PromptTemplate.from_template("tell me a joke about {topic}.")
            model = ChatOpenAI(temperature=0.5)

            # Good
            chain = template | model.with_retry()

            # Bad
            chain = template | model
            retryable_chain = chain.with_retry()
    """

    retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
    """The exception types to retry on. By default all exceptions are retried.

    In general you should only retry on exceptions that are likely to be
    transient, such as network errors.

    Good exceptions to retry are all server errors (5xx) and selected client
    errors (4xx) such as 429 Too Many Requests.
    """

    wait_exponential_jitter: bool = True
    """Whether to add jitter to the exponential backoff."""

    max_attempt_number: int = 3
    """The maximum number of attempts to retry the runnable."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    @property
    def _kwargs_retrying(self) -> Dict[str, Any]:
        """Build the tenacity keyword arguments shared by sync and async retrying.

        Settings that are falsy (e.g. ``max_attempt_number=0`` or an empty
        ``retry_exception_types``) are omitted, so tenacity's own defaults
        apply for them.
        """
        kwargs: Dict[str, Any] = {}

        if self.max_attempt_number:
            kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

        if self.wait_exponential_jitter:
            kwargs["wait"] = wait_exponential_jitter()

        if self.retry_exception_types:
            kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

        return kwargs

    def _sync_retrying(self, **kwargs: Any) -> Retrying:
        """Create a synchronous tenacity ``Retrying`` controller."""
        return Retrying(**self._kwargs_retrying, **kwargs)

    def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
        """Create an asynchronous tenacity ``AsyncRetrying`` controller."""
        return AsyncRetrying(**self._kwargs_retrying, **kwargs)

    def _patch_config(
        self,
        config: RunnableConfig,
        run_manager: "T",
        retry_state: RetryCallState,
    ) -> RunnableConfig:
        """Return ``config`` with child callbacks for the current attempt.

        Attempts after the first are tagged ``retry:attempt:<n>`` so that
        retries can be told apart in traces.
        """
        attempt = retry_state.attempt_number
        tag = f"retry:attempt:{attempt}" if attempt > 1 else None
        return patch_config(config, callbacks=run_manager.get_child(tag))

    def _patch_config_list(
        self,
        config: List[RunnableConfig],
        run_manager: List["T"],
        retry_state: RetryCallState,
    ) -> List[RunnableConfig]:
        """Apply ``_patch_config`` pairwise to parallel config/manager lists."""
        return [
            self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
        ]

    @staticmethod
    def _pick(items: List[U], indices: List[int]) -> List[U]:
        """Return the elements of ``items`` at ``indices``, preserving order."""
        return [items[i] for i in indices]

    @staticmethod
    def _record_attempt(
        indices: List[int],
        result: List[Any],
        results_map: Dict[int, Any],
        errors_map: Dict[int, Exception],
    ) -> Optional[Exception]:
        """Store one batch attempt's per-item results under their original indices.

        Args:
            indices: Original positions (in the full input list) that were sent
                in this attempt; ``result[k]`` belongs to ``indices[k]``.
            result: Per-item outputs of the attempt (exceptions for failures).
            results_map: Successful outputs keyed by original position (updated).
            errors_map: Most recent exception keyed by original position (updated).

        Returns:
            The first exception in ``result``, or None if every item succeeded.
        """
        first_exception: Optional[Exception] = None
        for idx, item in zip(indices, result):
            if isinstance(item, Exception):
                errors_map[idx] = item
                if first_exception is None:
                    first_exception = item
            else:
                results_map[idx] = item
        return first_exception

    @staticmethod
    def _collect_outputs(
        size: int,
        results_map: Dict[int, Any],
        errors_map: Dict[int, Exception],
        retry_error: Optional[RetryError],
    ) -> List[Any]:
        """Assemble outputs in input order.

        Each position gets its successful output if there is one, else its most
        recent exception, else ``retry_error``. The last case covers inputs that
        never produced a per-item result, e.g. because the underlying batch call
        itself raised.
        """
        outputs: List[Any] = []
        for idx in range(size):
            if idx in results_map:
                outputs.append(results_map[idx])
            elif idx in errors_map:
                outputs.append(errors_map[idx])
            else:
                outputs.append(retry_error)
        return outputs

    def _invoke(
        self,
        input: Input,
        run_manager: "CallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        for attempt in self._sync_retrying(reraise=True):
            with attempt:
                result = super().invoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            # The attempt's outcome is only set once the ``with`` block exits;
            # record the actual return value on it.
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the bound runnable, retrying on failure."""
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Input,
        run_manager: "AsyncCallbackManagerForChainRun",
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        async for attempt in self._async_retrying(reraise=True):
            with attempt:
                result = await super().ainvoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                    **kwargs,
                )
            # The attempt's outcome is only set once the ``with`` block exits;
            # record the actual return value on it.
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Asynchronously invoke the bound runnable, retrying on failure."""
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _batch(
        self,
        inputs: List[Input],
        run_manager: List["CallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        """Batch ``inputs``, retrying only the items that failed.

        Each attempt re-runs only the inputs that have not succeeded yet. Their
        results are mapped back to their positions in ``inputs``. When retries
        are exhausted, each failed position holds its most recent exception, or
        the ``RetryError`` if no per-item exception was recorded. An exception
        of a non-retryable type propagates to the caller.
        """
        results_map: Dict[int, Output] = {}
        errors_map: Dict[int, Exception] = {}
        retry_error: Optional[RetryError] = None
        try:
            for attempt in self._sync_retrying():
                with attempt:
                    # Original positions still lacking a result. This attempt's
                    # results are indexed relative to this list, not ``inputs``.
                    remaining = [i for i in range(len(inputs)) if i not in results_map]
                    result = super().batch(
                        self._pick(inputs, remaining),
                        self._patch_config_list(
                            self._pick(config, remaining),
                            self._pick(run_manager, remaining),
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    first_exception = self._record_attempt(
                        remaining, result, results_map, errors_map
                    )
                    # Raise so tenacity schedules another attempt for the failures.
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            retry_error = e

        return self._collect_outputs(len(inputs), results_map, errors_map, retry_error)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        """Run the bound runnable on ``inputs``, retrying failed items."""
        return self._batch_with_config(
            self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    async def _abatch(
        self,
        inputs: List[Input],
        run_manager: List["AsyncCallbackManagerForChainRun"],
        config: List[RunnableConfig],
        **kwargs: Any,
    ) -> List[Union[Output, Exception]]:
        """Async counterpart of ``_batch``: retry only the items that failed.

        Each attempt re-runs only the inputs that have not succeeded yet. Their
        results are mapped back to their positions in ``inputs``. When retries
        are exhausted, each failed position holds its most recent exception, or
        the ``RetryError`` if no per-item exception was recorded. An exception
        of a non-retryable type propagates to the caller.
        """
        results_map: Dict[int, Output] = {}
        errors_map: Dict[int, Exception] = {}
        retry_error: Optional[RetryError] = None
        try:
            async for attempt in self._async_retrying():
                with attempt:
                    # Original positions still lacking a result. This attempt's
                    # results are indexed relative to this list, not ``inputs``.
                    remaining = [i for i in range(len(inputs)) if i not in results_map]
                    result = await super().abatch(
                        self._pick(inputs, remaining),
                        self._patch_config_list(
                            self._pick(config, remaining),
                            self._pick(run_manager, remaining),
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                        **kwargs,
                    )
                    first_exception = self._record_attempt(
                        remaining, result, results_map, errors_map
                    )
                    # Raise so tenacity schedules another attempt for the failures.
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            retry_error = e

        return self._collect_outputs(len(inputs), results_map, errors_map, retry_error)

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[Output]:
        """Asynchronously run the bound runnable on ``inputs``, retrying failed items."""
        return await self._abatch_with_config(
            self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    # stream() and transform() are not retried because retrying a stream
    # is not very intuitive.