from __future__ import annotations
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import Generation
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.retrievers import BaseRetriever
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
QUESTION_GENERATOR_PROMPT,
FinishedOutputParser,
)
from langchain.chains.llm import LLMChain
class _ResponseChain(LLMChain):
    """Base class for chains that generate responses."""

    prompt: BasePromptTemplate = PROMPT

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Always return ``False`` for this chain."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys of the chain, taken from the prompt's input variables."""
        return self.prompt.input_variables

    def generate_tokens_and_log_probs(
        self,
        _input: Dict[str, Any],
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Generate for a single input and return its tokens and log probs.

        Args:
            _input: Mapping of prompt variables for one generation request.
            run_manager: Optional callback manager forwarded to ``generate``.

        Returns:
            The ``(tokens, log_probs)`` pair produced by
            ``_extract_tokens_and_log_probs`` for the first (and only) input.
        """
        # ``generate`` works on a batch, so wrap the single input in a list
        # and read back the generations of that one entry.
        batch = [_input]
        result = self.generate(batch, run_manager=run_manager)
        first_generations = result.generations[0]
        return self._extract_tokens_and_log_probs(first_generations)

    @abstractmethod
    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Extract tokens and log probabilities from the response."""
class _OpenAIResponseChain(_ResponseChain):
    """Chain that generates responses from user input and context."""

    llm: BaseLanguageModel

    def _extract_tokens_and_log_probs(
        self, generations: List[Generation]
    ) -> Tuple[Sequence[str], Sequence[float]]:
        """Collect tokens and their log probabilities from the generations.

        Each generation is expected to carry
        ``generation_info["logprobs"]`` with ``"tokens"`` and
        ``"token_logprobs"`` lists; the lists of all generations are
        concatenated in order.

        Args:
            generations: Generations returned by the LLM for one input.

        Returns:
            A ``(tokens, log_probs)`` pair of parallel lists.

        Raises:
            ValueError: If a generation has no ``generation_info`` or its
                ``generation_info`` has no ``"logprobs"`` entry.
        """
        tokens: List[str] = []
        log_probs: List[float] = []
        for gen in generations:
            info = gen.generation_info
            if info is None:
                raise ValueError(
                    "Generation has no `generation_info`; token log "
                    "probabilities are required to find low-confidence spans."
                )
            logprobs = info.get("logprobs")
            if logprobs is None:
                # Fail with a clear message instead of an opaque
                # KeyError/TypeError when log probs were not returned.
                raise ValueError(
                    "Generation info has no `logprobs`; the LLM must be "
                    "configured to return token log probabilities."
                )
            tokens.extend(logprobs["tokens"])
            log_probs.extend(logprobs["token_logprobs"])
        return tokens, log_probs
class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Always return ``False`` for this chain."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input", "context", "response"]
def _low_confidence_spans(
tokens: Sequence[str],
log_probs: Sequence[float],
min_prob: float,
min_token_gap: int,
num_pad_tokens: int,
) -> List[str]:
_low_idx = np.where(np.exp(log_probs) < min_prob)[0]
low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
if len(low_idx) == 0:
return []
spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
for i, idx in enumerate(low_idx[1:]):
end = idx + num_pad_tokens + 1
if idx - low_idx[i] < min_token_gap:
spans[-1][1] = end
else:
spans.append([idx, end])
return ["".join(tokens[start:end]) for start, end in spans]
class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator."""

    question_generator_chain: QuestionGeneratorChain
    """Chain that generates questions from uncertain spans."""
    response_chain: _ResponseChain
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that fetches relevant documents for a query."""
    min_prob: float = 0.2
    """Probability below which a token is considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low-confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens after a low-confidence token included in its span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval.

    NOTE(review): this flag is not read by any method of this class.
    """

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        """Retrieve documents for the questions and generate a continuation.

        Args:
            questions: Queries passed one by one to the retriever.
            user_input: The original user input.
            response: The response accumulated so far.
            _run_manager: Callback manager; its child is passed as callbacks.

        Returns:
            The ``(marginal, finished)`` pair obtained by parsing the
            response chain's output with ``output_parser``.
        """
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        # All retrieved documents form a single context string.
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.predict(
            user_input=user_input,
            context=context,
            response=response,
            callbacks=callbacks,
        )
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        """Turn low-confidence spans into questions, then regenerate.

        Args:
            low_confidence_spans: Text of the uncertain spans.
            _run_manager: Callback manager used for child callbacks and
                for logging the generated questions.
            user_input: The original user input.
            response: The response accumulated before this iteration.
            initial_response: ``response`` plus the tentative continuation
                that contained the uncertain spans.

        Returns:
            The ``(marginal, finished)`` pair from ``_do_generation``.
        """
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        question_gen_outputs = self.question_generator_chain.apply(
            question_gen_inputs, callbacks=callbacks
        )
        questions = [
            output[self.question_generator_chain.output_keys[0]]
            for output in question_gen_outputs
        ]
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the iterative generate / check / retrieve loop.

        Each of at most ``max_iter`` iterations generates a tentative
        continuation with an empty context and looks for low-confidence
        spans in it. Without such spans the continuation is accepted as is;
        otherwise questions are generated for the spans, documents are
        retrieved, and the continuation is regenerated with that context.

        Args:
            inputs: Chain inputs; the value under ``input_keys[0]`` is used
                as the user input.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A dict mapping ``output_keys[0]`` to the final response.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        user_input = inputs[self.input_keys[0]]

        response = ""

        for _ in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = self.response_chain.generate_tokens_and_log_probs(
                _input, run_manager=_run_manager
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                # Confident continuation: accept it without retrieval.
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue

            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Create a FlareChain from a language model.

        The given ``llm`` drives the question generator chain. The response
        chain uses a separately constructed ``langchain_openai.OpenAI``
        model with ``logprobs`` enabled and ``temperature=0``.

        Args:
            llm: Language model used by the question generator chain.
            max_generation_len: ``max_tokens`` for the OpenAI response model.
            **kwargs: Additional arguments passed to the constructor.

        Returns:
            A FlareChain built around the given language model.

        Raises:
            ImportError: If ``langchain_openai`` cannot be imported.
        """
        try:
            from langchain_openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai: "
                "`pip install langchain-openai`"
            ) from e
        question_gen_chain = QuestionGeneratorChain(llm=llm)
        response_llm = OpenAI(
            max_tokens=max_generation_len, model_kwargs={"logprobs": 1}, temperature=0
        )
        response_chain = _OpenAIResponseChain(llm=response_llm)
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )