# Source code for langchain.output_parsers.retry
from __future__ import annotations
from typing import Any, TypeVar
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
# Raw template asking the LLM to redo a completion that did not satisfy the
# original prompt. Placeholders: {prompt}, {completion}.
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Please try again:"""

# Same as NAIVE_COMPLETION_RETRY, but additionally surfaces the parse error to
# the LLM through the {error} placeholder.
NAIVE_COMPLETION_RETRY_WITH_ERROR = """Prompt:
{prompt}
Completion:
{completion}
Above, the Completion did not satisfy the constraints given in the Prompt.
Details: {error}
Please try again:"""

# Prompt templates built from the raw strings above (same order as the names
# on the left-hand side).
NAIVE_RETRY_PROMPT, NAIVE_RETRY_WITH_ERROR_PROMPT = (
    PromptTemplate.from_template(template_text)
    for template_text in (NAIVE_COMPLETION_RETRY, NAIVE_COMPLETION_RETRY_WITH_ERROR)
)

# Type of the value produced by the wrapped parser.
T = TypeVar("T")
class RetryOutputParser(BaseOutputParser[T]):
    """Wrap a parser and try to repair its parsing errors.

    When the wrapped parser rejects a completion, the original prompt text and
    the rejected completion are handed to ``retry_chain``, which is told that
    the completion did not satisfy the prompt's constraints; its output is then
    parsed again.
    """

    # Parser used to parse the (possibly retried) completion.
    parser: BaseOutputParser[T]
    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
    # Chain invoked via ``run``/``arun`` with ``prompt`` and ``completion``
    # keyword arguments to produce a new completion.
    retry_chain: Any
    # Maximum number of retries after the first parse attempt.
    max_retries: int = 1

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        parser: BaseOutputParser[T],
        prompt: BasePromptTemplate = NAIVE_RETRY_PROMPT,
        max_retries: int = 1,
    ) -> RetryOutputParser[T]:
        """Build a RetryOutputParser whose retry chain is an ``LLMChain``.

        Args:
            llm: Language model used to regenerate failed completions.
            prompt: Template for the retry request; it receives ``prompt``
                and ``completion`` variables.
            parser: Parser applied to each candidate completion.
            max_retries: Maximum number of retries.

        Returns:
            A configured RetryOutputParser.
        """
        # Imported lazily to avoid a top-level dependency on langchain.chains.
        from langchain.chains.llm import LLMChain

        return cls(
            parser=parser,
            retry_chain=LLMChain(llm=llm, prompt=prompt),
            max_retries=max_retries,
        )

    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
        """Parse ``completion``, asking the retry chain to fix it on failure.

        Args:
            completion: The LLM output to parse.
            prompt_value: The prompt that produced ``completion``; its string
                form is sent to the retry chain.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: The last parse error once ``max_retries``
                retries have been used, or a generic "Failed to parse" error
                when ``max_retries`` is negative (no attempt is made then).
        """
        for attempt in range(self.max_retries + 1):
            try:
                return self.parser.parse(completion)
            except OutputParserException:
                if attempt == self.max_retries:
                    raise
                # Ask for a corrected completion and loop to parse it.
                completion = self.retry_chain.run(
                    prompt=prompt_value.to_string(), completion=completion
                )
        raise OutputParserException("Failed to parse")

    async def aparse_with_prompt(
        self, completion: str, prompt_value: PromptValue
    ) -> T:
        """Async variant of ``parse_with_prompt``.

        Args:
            completion: The LLM output to parse.
            prompt_value: The prompt that produced ``completion``.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: Same conditions as ``parse_with_prompt``.
        """
        for attempt in range(self.max_retries + 1):
            try:
                return await self.parser.aparse(completion)
            except OutputParserException:
                if attempt == self.max_retries:
                    raise
                completion = await self.retry_chain.arun(
                    prompt=prompt_value.to_string(), completion=completion
                )
        raise OutputParserException("Failed to parse")

    def parse(self, completion: str) -> T:
        """Unsupported: a prompt is required, use ``parse_with_prompt``."""
        raise NotImplementedError(
            "This OutputParser can only be called by the `parse_with_prompt` method."
        )

    @property
    def _type(self) -> str:
        return "retry"
class RetryWithErrorOutputParser(BaseOutputParser[T]):
    """Wrap a parser and try to repair its parsing errors, reporting the error.

    When the wrapped parser rejects a completion, the original prompt text,
    the rejected completion and the ``repr`` of the raised
    ``OutputParserException`` are handed to ``retry_chain``. Unlike
    ``RetryOutputParser``, the LLM therefore sees the actual error, which in
    principle gives it more information about how to fix the output.
    """

    # Parser used to parse the (possibly retried) completion.
    parser: BaseOutputParser[T]
    # Should be an LLMChain but we want to avoid top-level imports from langchain.chains
    # Chain invoked via ``run``/``arun`` with ``prompt``, ``completion`` and
    # ``error`` keyword arguments to produce a new completion.
    retry_chain: Any
    # Maximum number of retries after the first parse attempt.
    max_retries: int = 1

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        parser: BaseOutputParser[T],
        prompt: BasePromptTemplate = NAIVE_RETRY_WITH_ERROR_PROMPT,
        max_retries: int = 1,
    ) -> RetryWithErrorOutputParser[T]:
        """Build a RetryWithErrorOutputParser whose retry chain is an ``LLMChain``.

        Args:
            llm: Language model used to regenerate failed completions.
            parser: Parser applied to each candidate completion.
            prompt: Template for the retry request; it receives ``prompt``,
                ``completion`` and ``error`` variables.
            max_retries: Maximum number of retries.

        Returns:
            A configured RetryWithErrorOutputParser.
        """
        # Imported lazily to avoid a top-level dependency on langchain.chains.
        from langchain.chains.llm import LLMChain

        return cls(
            parser=parser,
            retry_chain=LLMChain(llm=llm, prompt=prompt),
            max_retries=max_retries,
        )

    def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
        """Parse ``completion``, feeding parse errors back to the retry chain.

        Args:
            completion: The LLM output to parse.
            prompt_value: The prompt that produced ``completion``; its string
                form is sent to the retry chain.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: The last parse error once ``max_retries``
                retries have been used, or a generic "Failed to parse" error
                when ``max_retries`` is negative (no attempt is made then).
        """
        for attempt in range(self.max_retries + 1):
            try:
                return self.parser.parse(completion)
            except OutputParserException as exc:
                if attempt == self.max_retries:
                    raise
                # Show the LLM what went wrong and loop to parse its answer.
                completion = self.retry_chain.run(
                    prompt=prompt_value.to_string(),
                    completion=completion,
                    error=repr(exc),
                )
        raise OutputParserException("Failed to parse")

    async def aparse_with_prompt(
        self, completion: str, prompt_value: PromptValue
    ) -> T:
        """Async variant of ``parse_with_prompt``.

        Args:
            completion: The LLM output to parse.
            prompt_value: The prompt that produced ``completion``.

        Returns:
            The value produced by the wrapped parser.

        Raises:
            OutputParserException: Same conditions as ``parse_with_prompt``.
        """
        for attempt in range(self.max_retries + 1):
            try:
                return await self.parser.aparse(completion)
            except OutputParserException as exc:
                if attempt == self.max_retries:
                    raise
                completion = await self.retry_chain.arun(
                    prompt=prompt_value.to_string(),
                    completion=completion,
                    error=repr(exc),
                )
        raise OutputParserException("Failed to parse")

    def parse(self, completion: str) -> T:
        """Unsupported: a prompt is required, use ``parse_with_prompt``."""
        raise NotImplementedError(
            "This OutputParser can only be called by the `parse_with_prompt` method."
        )

    @property
    def _type(self) -> str:
        return "retry_with_error"