Source code for langchain.chains.combine_documents.map_rerank

"""通过首先在文档上映射链,然后重新排列结果来合并文档。"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.output_parsers.regex import RegexParser
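
# --- Editor's illustrative sketch, not part of the original module. ---
# The chain below assumes its LLMChain's prompt carries a RegexParser that
# splits each completion into an answer and a numeric score. A minimal example
# of that contract (the regex and keys here mirror the class docstring and are
# assumptions, not a required format):
#
#     parser = RegexParser(
#         regex=r"(.*?)\nScore: (.*)",
#         output_keys=["answer", "score"],
#     )
#     parser.parse("H2O\nScore: 95")
#     # -> {"answer": "H2O", "score": "95"}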


class MapRerankDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then reranking results.

    This algorithm calls an LLMChain on each input document. The LLMChain is
    expected to have an OutputParser that parses the result into both an
    answer (`answer_key`) and a score (`rank_key`). The answer with the
    highest score is then returned.

    Example:
        .. code-block:: python

            from langchain.chains import StuffDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI
            from langchain.output_parsers.regex import RegexParser

            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`.
            # The actual prompt will need to be a lot more complex; this is
            # just an example.
            prompt_template = (
                "Use the following context to tell me the chemical formula "
                "for water. Output both your answer and a score of how "
                "confident you are. Context: {context}"
            )
            output_parser = RegexParser(
                regex=r"(.*?)\nScore: (.*)",
                output_keys=["answer", "score"],
            )
            prompt = PromptTemplate(
                template=prompt_template,
                input_variables=["context"],
                output_parser=output_parser,
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            chain = MapRerankDocumentsChain(
                llm_chain=llm_chain,
                document_variable_name=document_variable_name,
                rank_key="score",
                answer_key="answer",
            )
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    rank_key: str
    """Key in output of llm_chain to rank on."""
    answer_key: str
    """Key in output of llm_chain to return as answer."""
    metadata_keys: Optional[List[str]] = None
    """Additional metadata from the chosen document to return."""
    return_intermediate_steps: bool = False
    """Return intermediate steps.
    Intermediate steps include the results of calling llm_chain on each document."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        schema: Dict[str, Any] = {
            self.output_key: (str, None),
        }
        if self.return_intermediate_steps:
            schema["intermediate_steps"] = (List[str], None)
        if self.metadata_keys:
            schema.update({key: (Any, None) for key in self.metadata_keys})

        return create_model("MapRerankOutput", **schema)
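
    # Editor's sketch (assumes the default output_key "output_text" inherited
    # from BaseCombineDocumentsChain): with return_intermediate_steps=True and
    # metadata_keys=["source"], the model built above has roughly the fields
    #     {"output_text": str, "intermediate_steps": List[str], "source": Any}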

    @property
    def output_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        if self.metadata_keys is not None:
            _output_keys += self.metadata_keys
        return _output_keys

    @root_validator()
    def validate_llm_output(cls, values: Dict) -> Dict:
        """Validate that the combine chain outputs a dictionary."""
        output_parser = values["llm_chain"].prompt.output_parser
        if not isinstance(output_parser, RegexParser):
            raise ValueError(
                "Output parser of llm_chain should be a RegexParser,"
                f" got {output_parser}"
            )
        output_keys = output_parser.output_keys
        if values["rank_key"] not in output_keys:
            raise ValueError(
                f"Got {values['rank_key']} as key to rank on, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        if values["answer_key"] not in output_keys:
            raise ValueError(
                f"Got {values['answer_key']} as key to return, but did not find "
                f"it in the llm_chain output keys ({output_keys})"
            )
        return values

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided."""
        if "document_variable_name" not in values:
            llm_chain_variables = values["llm_chain"].prompt.input_variables
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            llm_chain_variables = values["llm_chain"].prompt.input_variables
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values
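
    # Editor's sketch (hypothetical prompt, not part of the original module):
    # a single-variable prompt lets the pre-validator above infer the
    # document variable name, e.g.
    #
    #     prompt = PromptTemplate.from_template("Summarize: {context}")
    #     # document_variable_name is inferred as "context"; with two or more
    #     # input variables it must be passed explicitly, else ValueError.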

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

        Combine by mapping a first chain over all documents, then reranking
        the results.

        Args:
            docs: List of documents to combine.
            callbacks: Callbacks to be passed through.
            **kwargs: Additional parameters to be passed to LLM calls (like
                other input variables besides the documents).

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        results = self.llm_chain.apply_and_parse(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(docs, results)

    async def acombine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents in a map rerank manner.

        Combine by mapping a first chain over all documents, then reranking
        the results.

        Args:
            docs: List of documents to combine.
            callbacks: Callbacks to be passed through.
            **kwargs: Additional parameters to be passed to LLM calls (like
                other input variables besides the documents).

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        results = await self.llm_chain.aapply_and_parse(
            # FYI - this is parallelized and so it is fast.
            [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(docs, results)
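
    # Editor's sketch of the mapped inputs (hypothetical values): with
    # document_variable_name="context", docs [Document(page_content="A"),
    # Document(page_content="B")] and kwargs {"question": "q"}, both methods
    # above send the LLMChain
    #     [{"context": "A", "question": "q"}, {"context": "B", "question": "q"}]
    # and each parsed result must contain both rank_key and answer_key.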

    def _process_results(
        self,
        docs: List[Document],
        results: Sequence[Union[str, List[str], Dict[str, str]]],
    ) -> Tuple[str, dict]:
        typed_results = cast(List[dict], results)
        sorted_res = sorted(
            zip(typed_results, docs), key=lambda x: -int(x[0][self.rank_key])
        )
        output, document = sorted_res[0]
        extra_info = {}
        if self.metadata_keys is not None:
            for key in self.metadata_keys:
                extra_info[key] = document.metadata[key]
        if self.return_intermediate_steps:
            extra_info["intermediate_steps"] = results
        return output[self.answer_key], extra_info

    @property
    def _chain_type(self) -> str:
        return "map_rerank_documents_chain"
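

# --- Editor's usage sketch, not part of the original module. Assumes the
# langchain_community OpenAI wrapper is installed and OPENAI_API_KEY is set;
# the prompt and regex are illustrative assumptions, not a required format,
# and the LLM must emit a trailing "Score: NN" line for parsing to succeed.
if __name__ == "__main__":
    from langchain_community.llms import OpenAI
    from langchain_core.prompts import PromptTemplate

    output_parser = RegexParser(
        regex=r"(.*?)\nScore: (.*)",
        output_keys=["answer", "score"],
    )
    prompt = PromptTemplate(
        template=(
            "Answer the question using only this context, then rate your "
            "confidence from 0 to 100 on a new line starting with 'Score: '.\n"
            "Context: {context}\nQuestion: {question}"
        ),
        input_variables=["context", "question"],
        output_parser=output_parser,
    )
    chain = MapRerankDocumentsChain(
        llm_chain=LLMChain(llm=OpenAI(), prompt=prompt),
        document_variable_name="context",
        rank_key="score",
        answer_key="answer",
    )
    docs = [
        Document(page_content="Water's chemical formula is H2O."),
        Document(page_content="Photosynthesis takes place in chloroplasts."),
    ]
    # combine_docs maps the LLMChain over both documents and returns the
    # answer whose parsed score is highest.
    answer, _extra = chain.combine_docs(
        docs, question="What is the chemical formula for water?"
    )
    print(answer)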