"""通过首先在文档上映射链,然后组合结果来合并文档。"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by mapping a chain over them, then combining results.

    First `llm_chain` is called on each document individually, passing in the
    document's `page_content` and any other kwargs. This is the `map` step.

    The results of that `map` step are then handed to `reduce_documents_chain`
    in the `reduce` step. This is typically a ReduceDocumentsChain.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain,
                LLMChain,
                ReduceDocumentsChain,
                MapReduceDocumentsChain,
            )
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for
            # more details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take `document_variable_name` as an
            # input variable.
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            # We now define how to combine these summaries.
            reduce_prompt = PromptTemplate.from_template(
                "Combine these summaries: {context}"
            )
            reduce_llm_chain = LLMChain(llm=llm, prompt=reduce_prompt)
            combine_documents_chain = StuffDocumentsChain(
                llm_chain=reduce_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
            # If we wanted to, we could also pass in a
            # collapse_documents_chain, which is specifically aimed at
            # collapsing documents BEFORE the final call.
            collapse_prompt = PromptTemplate.from_template(
                "Collapse this content: {context}"
            )
            collapse_llm_chain = LLMChain(llm=llm, prompt=collapse_prompt)
            collapse_documents_chain = StuffDocumentsChain(
                llm_chain=collapse_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name
            )
            reduce_documents_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_documents_chain,
                collapse_documents_chain=collapse_documents_chain,
            )
            chain = MapReduceDocumentsChain(
                llm_chain=llm_chain,
                reduce_documents_chain=reduce_documents_chain,
            )
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    reduce_documents_chain: BaseCombineDocumentsChain
    """Chain used to reduce the results of applying `llm_chain` to each document.
    Typically either a ReduceDocumentsChain or a StuffDocumentsChain."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If there is only one variable in the llm_chain, this need not be provided."""
    return_intermediate_steps: bool = False
    """Return the results of the map steps in the output."""

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Return the output schema for this chain.

        When `return_intermediate_steps` is set, a dedicated model is built
        that exposes both `output_key` and `intermediate_steps`; otherwise the
        parent implementation is used.
        """
        if self.return_intermediate_steps:
            return create_model(
                "MapReduceDocumentsOutput",
                **{
                    self.output_key: (str, None),
                    "intermediate_steps": (List[str], None),
                },  # type: ignore[call-overload]
            )
        return super().get_output_schema(config)

    @property
    def output_keys(self) -> List[str]:
        """Output keys produced by this chain.

        The parent's keys, plus `intermediate_steps` when
        `return_intermediate_steps` is enabled.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            # Build a new list rather than mutating the parent's return value.
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_reduce_chain(cls, values: Dict) -> Dict:
        """Translate the deprecated `combine_document_chain` argument.

        Kept for backwards compatibility: a legacy `combine_document_chain`
        (and optional `collapse_document_chain`) is wrapped in a
        ReduceDocumentsChain and stored under `reduce_documents_chain`.

        Raises:
            ValueError: If both `combine_document_chain` and
                `reduce_documents_chain` are supplied.
        """
        if "combine_document_chain" in values:
            if "reduce_documents_chain" in values:
                raise ValueError(
                    "Both `reduce_documents_chain` and `combine_document_chain` "
                    "cannot be provided at the same time. `combine_document_chain` "
                    "is deprecated, please only provide `reduce_documents_chain`"
                )
            combine_chain = values["combine_document_chain"]
            collapse_chain = values.get("collapse_document_chain")
            reduce_chain = ReduceDocumentsChain(
                combine_documents_chain=combine_chain,
                collapse_documents_chain=collapse_chain,
            )
            values["reduce_documents_chain"] = reduce_chain
            # Remove the legacy keys so `Extra.forbid` does not reject them.
            del values["combine_document_chain"]
            if "collapse_document_chain" in values:
                del values["collapse_document_chain"]
        return values

    @root_validator(pre=True)
    def get_return_intermediate_steps(cls, values: Dict) -> Dict:
        """Translate the deprecated `return_map_steps` argument.

        Kept for backwards compatibility: the value is moved to
        `return_intermediate_steps`.
        """
        if "return_map_steps" in values:
            values["return_intermediate_steps"] = values["return_map_steps"]
            del values["return_map_steps"]
        return values

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Fill in or validate `document_variable_name`.

        If the name is not provided and the `llm_chain` prompt has exactly one
        input variable, that variable is used. If the name is provided, it
        must be one of the prompt's input variables.

        Raises:
            ValueError: If `llm_chain` is missing, if the name is omitted while
                the prompt has a number of input variables other than one, or
                if the provided name is not among the prompt's input variables.
        """
        if "llm_chain" not in values:
            # Raise a ValueError instead of leaking a bare KeyError from the
            # lookup below when the required field is absent.
            raise ValueError("llm_chain must be provided")
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        elif values["document_variable_name"] not in llm_chain_variables:
            raise ValueError(
                f"document_variable_name {values['document_variable_name']} was "
                f"not found in llm_chain input_variables: {llm_chain_variables}"
            )
        return values

    @property
    def collapse_document_chain(self) -> BaseCombineDocumentsChain:
        """Collapse chain of the reduce step; kept for backwards compatibility.

        Falls back to the reduce step's `combine_documents_chain` when no
        collapse chain is configured.

        Raises:
            ValueError: If `reduce_documents_chain` is not a
                ReduceDocumentsChain.
        """
        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
            if self.reduce_documents_chain.collapse_documents_chain:
                return self.reduce_documents_chain.collapse_documents_chain
            else:
                return self.reduce_documents_chain.combine_documents_chain
        else:
            raise ValueError(
                f"`reduce_documents_chain` is of type "
                f"{type(self.reduce_documents_chain)} so it does not have "
                f"this attribute."
            )

    @property
    def combine_document_chain(self) -> BaseCombineDocumentsChain:
        """Combine chain of the reduce step; kept for backwards compatibility.

        Raises:
            ValueError: If `reduce_documents_chain` is not a
                ReduceDocumentsChain.
        """
        if isinstance(self.reduce_documents_chain, ReduceDocumentsChain):
            return self.reduce_documents_chain.combine_documents_chain
        else:
            raise ValueError(
                f"`reduce_documents_chain` is of type "
                f"{type(self.reduce_documents_chain)} so it does not have "
                f"this attribute."
            )

    def _build_map_inputs(
        self, docs: List[Document], **kwargs: Any
    ) -> List[Dict[str, Any]]:
        """Build one `llm_chain` input dict per document for the map step."""
        # kwargs are spread last, so a caller-supplied key with the same name
        # as `document_variable_name` overrides the page content.
        return [{self.document_variable_name: d.page_content, **kwargs} for d in docs]

    def _map_results_to_docs(
        self, docs: List[Document], map_results: List[Dict[str, Any]]
    ) -> List[Document]:
        """Wrap each map output in a Document carrying its source's metadata."""
        question_result_key = self.llm_chain.output_key
        return [
            # Text comes from the map result, metadata from the source doc.
            Document(page_content=r[question_result_key], metadata=doc.metadata)
            for doc, r in zip(docs, map_results)
        ]

    def combine_docs(
        self,
        docs: List[Document],
        token_max: Optional[int] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

        `llm_chain` is applied to every document (map), then the per-document
        outputs are passed to `reduce_documents_chain.combine_docs` (reduce).

        Args:
            docs: Documents to combine.
            token_max: Forwarded to the reduce chain's `combine_docs`.
            callbacks: Callbacks forwarded to both the map and reduce steps.
            **kwargs: Extra inputs forwarded to both the map and reduce steps.

        Returns:
            The reduce step's output string and its extra-return dict, which
            additionally holds `intermediate_steps` (the map outputs) when
            `return_intermediate_steps` is enabled.
        """
        map_results = self.llm_chain.apply(
            # Original author's note: `apply` is parallelized and so it is fast.
            self._build_map_inputs(docs, **kwargs),
            callbacks=callbacks,
        )
        result_docs = self._map_results_to_docs(docs, map_results)
        result, extra_return_dict = self.reduce_documents_chain.combine_docs(
            result_docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        if self.return_intermediate_steps:
            question_result_key = self.llm_chain.output_key
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        return result, extra_return_dict

    async def acombine_docs(
        self,
        docs: List[Document],
        token_max: Optional[int] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Asynchronously combine documents in a map reduce manner.

        `llm_chain` is applied to every document (map), then the per-document
        outputs are passed to `reduce_documents_chain.acombine_docs` (reduce).

        Args:
            docs: Documents to combine.
            token_max: Forwarded to the reduce chain's `acombine_docs`.
            callbacks: Callbacks forwarded to both the map and reduce steps.
            **kwargs: Extra inputs forwarded to both the map and reduce steps.

        Returns:
            The reduce step's output string and its extra-return dict, which
            additionally holds `intermediate_steps` (the map outputs) when
            `return_intermediate_steps` is enabled.
        """
        map_results = await self.llm_chain.aapply(
            # Original author's note: `aapply` is parallelized and so it is fast.
            self._build_map_inputs(docs, **kwargs),
            callbacks=callbacks,
        )
        result_docs = self._map_results_to_docs(docs, map_results)
        result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
            result_docs, token_max=token_max, callbacks=callbacks, **kwargs
        )
        if self.return_intermediate_steps:
            question_result_key = self.llm_chain.output_key
            intermediate_steps = [r[question_result_key] for r in map_results]
            extra_return_dict["intermediate_steps"] = intermediate_steps
        return result, extra_return_dict

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "map_reduce_documents_chain"