"""通过首先在文档上映射链,然后重新排列结果来合并文档。"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.output_parsers.regex import RegexParser
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then reranking results.

    This algorithm calls an LLMChain on each input document. The LLMChain is
    expected to have an OutputParser that parses the result into both an answer
    (`answer_key`) and a score (`rank_key`). The answer with the highest score
    is then returned.

    Example:
        .. code-block:: python

from langchain.chains import StuffDocumentsChain, LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import OpenAI
from langchain.output_parsers.regex import RegexParser
document_variable_name = "context"
llm = OpenAI()
            # The prompt here should take the `document_variable_name` as an
            # input variable.
            # The actual prompt will need to be a lot more complex; this is
            # just an example.
            prompt_template = (
                "Use the following context to tell me the chemical formula "
                "for water. Output both your answer and a score of how "
                "confident you are. Context: {context}"
            )
            output_parser = RegexParser(
                regex=r"(.*?)\nScore: (.*)",
                output_keys=["answer", "score"],
            )
prompt = PromptTemplate(
template=prompt_template,
input_variables=["context"],
output_parser=output_parser,
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = MapRerankDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
rank_key="score",
answer_key="answer",
)
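
            # Usage sketch (illustrative addition, not part of the original
            # example). The combine-documents chain takes documents under the
            # "input_documents" key and returns the top-scoring answer under
            # "output_text".
            from langchain_core.documents import Document

            docs = [Document(page_content="The chemical formula for water is H2O.")]
            result = chain.invoke({"input_documents": docs})
            answer = result["output_text"]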
"""
    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    rank_key: str
    """Key in the output of llm_chain to rank on."""
    answer_key: str
    """Key in the output of llm_chain to return as the answer."""
    metadata_keys: Optional[List[str]] = None
    """Additional metadata from the chosen document to return."""
    return_intermediate_steps: bool = False
    """Return intermediate steps.
    Intermediate steps include the results of calling llm_chain on each document."""
class Config:
"""这个pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
    def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
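        # Build the output schema dynamically: the answer string is always
        # present; intermediate steps and any configured metadata keys are
        # added when requested.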
schema: Dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})
return create_model("MapRerankOutput", **schema)
@property
def output_keys(self) -> List[str]:
"""期望输入键。
:元数据 私有:
"""
_output_keys = super().output_keys
if self.return_intermediate_steps:
_output_keys = _output_keys + ["intermediate_steps"]
if self.metadata_keys is not None:
_output_keys += self.metadata_keys
return _output_keys
@root_validator()
def validate_llm_output(cls, values: Dict) -> Dict:
"""验证组合链输出的是一个字典。"""
output_parser = values["llm_chain"].prompt.output_parser
if not isinstance(output_parser, RegexParser):
raise ValueError(
"Output parser of llm_chain should be a RegexParser,"
f" got {output_parser}"
)
output_keys = output_parser.output_keys
if values["rank_key"] not in output_keys:
raise ValueError(
f"Got {values['rank_key']} as key to rank on, but did not find "
f"it in the llm_chain output keys ({output_keys})"
)
if values["answer_key"] not in output_keys:
raise ValueError(
f"Got {values['answer_key']} as key to return, but did not find "
f"it in the llm_chain output keys ({output_keys})"
)
return values
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""获取默认文档变量名称,如果未提供。"""
if "document_variable_name" not in values:
llm_chain_variables = values["llm_chain"].prompt.input_variables
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain input_variables"
)
else:
llm_chain_variables = values["llm_chain"].prompt.input_variables
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
    def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""在地图重新排列的方式下合并文档。
通过首先在所有文档上映射第一个链,然后重新排列结果来进行合并。
参数:
docs:要合并的文档列表
callbacks:要传递的回调函数
**kwargs:要传递给LLM调用的其他参数(例如文档之外的其他输入变量)
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(docs, results)
    async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""在地图重新排列的方式下合并文档。
通过首先在所有文档上映射第一个链,然后重新排列结果来进行合并。
参数:
docs:要合并的文档列表
callbacks:要传递的回调函数
**kwargs:要传递给LLM调用的其他参数(例如文档之外的其他输入变量)
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(docs, results)
def _process_results(
self,
docs: List[Document],
results: Sequence[Union[str, List[str], Dict[str, str]]],
) -> Tuple[str, dict]:
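        # Sort parsed results by the numeric score under `rank_key`, highest
        # first, and return the answer from the top-ranked document together
        # with any requested metadata and intermediate steps.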
typed_results = cast(List[dict], results)
sorted_res = sorted(
zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
)
output, document = sorted_res[0]
extra_info = {}
if self.metadata_keys is not None:
for key in self.metadata_keys:
extra_info[key] = document.metadata[key]
if self.return_intermediate_steps:
extra_info["intermediate_steps"] = results
return output[self.answer_key], extra_info
@property
def _chain_type(self) -> str:
return "map_rerank_documents_chain"