"""通过递归将许多文档合并在一起。"""
from __future__ import annotations
from typing import Any, Callable, List, Optional, Protocol, Tuple
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
[docs]class CombineDocsProtocol(Protocol):
"""合并文档方法的接口。"""
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""合并文档方法的接口。"""
[docs]class AsyncCombineDocsProtocol(Protocol):
"""合并文档方法的接口。"""
async def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""combine_docs方法的异步接口。"""
[docs]def split_list_of_docs(
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
) -> List[List[Document]]:
"""将文档分割成满足累积长度约束的子集。
参数:
docs:完整的文档列表。
length_func:用于计算一组文档的累积长度的函数。
token_max:文档子集的最大累积长度。
**kwargs:传递给每次调用length_func的任意额外关键字参数。
返回:
一个列表,其中包含多个列表,每个列表包含一组文档。
"""
new_result_doc_list = []
_sub_result_docs = []
for doc in docs:
_sub_result_docs.append(doc)
_num_tokens = length_func(_sub_result_docs, **kwargs)
if _num_tokens > token_max:
if len(_sub_result_docs) == 1:
raise ValueError(
"A single document was longer than the context length,"
" we cannot handle this."
)
new_result_doc_list.append(_sub_result_docs[:-1])
_sub_result_docs = _sub_result_docs[-1:]
new_result_doc_list.append(_sub_result_docs)
return new_result_doc_list
[docs]def collapse_docs(
docs: List[Document],
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
"""对一组文档执行合并函数,并合并它们的元数据。
参数:
docs: 要合并的文档列表。
combine_document_func: 一个函数,接受一个文档列表和可选的额外关键字参数,并将它们合并成一个字符串。
**kwargs: 要传递给combine_document_func的任意额外关键字参数。
返回值:
一个单一的文档,其中包含combine_document_func的输出作为页面内容,以及所有输入文档的合并元数据。所有元数据值都是字符串,对于文档之间存在重叠键的情况,值通过", "连接。
"""
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
[docs]async def acollapse_docs(
docs: List[Document],
combine_document_func: AsyncCombineDocsProtocol,
**kwargs: Any,
) -> Document:
"""对一组文档执行合并函数,并合并它们的元数据。
参数:
docs: 要合并的文档列表。
combine_document_func: 一个函数,接受一个文档列表和可选的额外关键字参数,并将它们合并成一个字符串。
**kwargs: 要传递给combine_document_func的任意额外关键字参数。
返回值:
一个单一的文档,其中包含combine_document_func的输出作为页面内容,以及所有输入文档的合并元数据。所有元数据值都是字符串,对于文档之间存在重叠键的情况,值通过", "连接。
"""
result = await combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
if k in combined_metadata:
combined_metadata[k] += f", {v}"
else:
combined_metadata[k] = str(v)
return Document(page_content=result, metadata=combined_metadata)
[docs]class ReduceDocumentsChain(BaseCombineDocumentsChain):
"""将文档通过递归方式合并。
这涉及到
- combine_documents_chain
- collapse_documents_chain
`combine_documents_chain`始终提供。这是最终调用的链。
我们将所有先前的结果传递给这个链,这个链的输出作为最终结果返回。
如果传入的文档太多,无法一次全部传递给`combine_documents_chain`,则会使用`collapse_documents_chain`。在这种情况下,
会对文档的尽可能大的组进行递归调用`collapse_documents_chain`。
示例:
.. code-block:: python
from langchain.chains import (
StuffDocumentsChain, LLMChain, ReduceDocumentsChain
)
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import OpenAI
# 这控制每个文档的格式。具体来说,
# 它将被传递给`format_document` - 查看该函数以获取更多详细信息。
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# 这里的提示应该将`document_variable_name`作为输入变量
prompt = PromptTemplate.from_template(
"总结这段内容: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
combine_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
)
# 如果需要,我们也可以传入collapse_documents_chain
# 这是专门用于在最终调用之前折叠文档的
prompt = PromptTemplate.from_template(
"折叠这段内容: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
collapse_documents_chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_documents_chain,
)"""
combine_documents_chain: BaseCombineDocumentsChain
"""最终链用于合并文档。
通常是一个StuffDocumentsChain。"""
collapse_documents_chain: Optional[BaseCombineDocumentsChain] = None
"""用于折叠文档以确保它们都能适应的链。
如果为None,则将使用combine_documents_chain。
通常是StuffDocumentsChain。"""
token_max: int = 3000
"""将文档分组为的最大令牌数。例如,如果设置为3000,则在尝试将它们组合成较小的块之前,文档将被分组为不超过3000个令牌的块。"""
collapse_max_retries: Optional[int] = None
"""将文档折叠以适应token_max的最大重试次数。
如果为None,则将继续尝试将文档折叠以适应token_max。
否则,在达到最大次数后,将抛出错误。"""
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def _collapse_chain(self) -> BaseCombineDocumentsChain:
if self.collapse_documents_chain is not None:
return self.collapse_documents_chain
else:
return self.combine_documents_chain
[docs] def combine_docs(
self,
docs: List[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""递归地合并多个文档。
参数:
docs:需要合并的文档列表,假设每个文档都少于`token_max`。
token_max:递归创建文档组,每组文档的令牌数少于此数字。
callbacks:要传递的回调函数
**kwargs:要传递给LLM调用的其他参数(例如文档之外的其他输入变量)
返回值:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
result_docs, extra_return_dict = self._collapse(
docs, token_max=token_max, callbacks=callbacks, **kwargs
)
return self.combine_documents_chain.combine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
[docs] async def acombine_docs(
self,
docs: List[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""异步递归合并多个文档。
参数:
docs:要合并的文档列表,假设每个文档都少于`token_max`。
token_max:递归创建文档组,每个组的令牌数少于此数字。
callbacks:要传递的回调函数
**kwargs:要传递给LLM调用的其他参数(如文档之外的其他输入变量)
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
result_docs, extra_return_dict = await self._acollapse(
docs, token_max=token_max, callbacks=callbacks, **kwargs
)
return await self.combine_documents_chain.acombine_docs(
docs=result_docs, callbacks=callbacks, **kwargs
)
def _collapse(
self,
docs: List[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
_token_max = token_max or self.token_max
retries: int = 0
while num_tokens is not None and num_tokens > _token_max:
new_result_doc_list = split_list_of_docs(
result_docs, length_func, _token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:
raise ValueError(
f"Exceed {self.collapse_max_retries} tries to \
collapse document to {_token_max} tokens."
)
return result_docs, {}
async def _acollapse(
self,
docs: List[Document],
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
result_docs = docs
length_func = self.combine_documents_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
async def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return await self._collapse_chain.arun(
input_documents=docs, callbacks=callbacks, **kwargs
)
_token_max = token_max or self.token_max
retries: int = 0
while num_tokens is not None and num_tokens > _token_max:
new_result_doc_list = split_list_of_docs(
result_docs, length_func, _token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = length_func(result_docs, **kwargs)
retries += 1
if self.collapse_max_retries and retries == self.collapse_max_retries:
raise ValueError(
f"Exceed {self.collapse_max_retries} tries to \
collapse document to {_token_max} tokens."
)
return result_docs, {}
@property
def _chain_type(self) -> str:
return "reduce_documents_chain"