"""用于组合文档的链的基本接口。"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import RunnableConfig
from langchain_core.runnables.utils import create_model
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain.chains.base import Chain
DEFAULT_DOCUMENT_SEPARATOR = "\n\n"
DOCUMENTS_KEY = "context"
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")
def _validate_prompt(prompt: BasePromptTemplate) -> None:
if DOCUMENTS_KEY not in prompt.input_variables:
raise ValueError(
f"Prompt must accept {DOCUMENTS_KEY} as an input variable. Received prompt "
f"with input variables: {prompt.input_variables}"
)
[docs]class BaseCombineDocumentsChain(Chain, ABC):
"""用于组合文档的基本接口。
这个链的子类处理以各种方式组合文档。这个基类的存在是为了在这些类型的链应该公开的接口中添加一些统一性。换句话说,它们期望与文档相关的输入键(默认为`input_documents`),然后还公开一种方法来计算来自文档的提示的长度(对于外部调用者来说很有用,以确定是否可以安全地将文档列表传递给这个链,或者是否会比上下文长度更长)。"""
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
[docs] def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
)
@property
def input_keys(self) -> List[str]:
"""期望输入键。
:元数据 私有:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""返回输出关键字。
:元数据 私有:
"""
return [self.output_key]
[docs] def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
"""给定传入的文档,返回提示长度。
调用方可以使用此函数来确定传入文档列表是否会超过某个提示长度。在尝试确保提示大小保持在某个上下文限制以下时,这很有用。
参数:
docs: List[Document],用于计算总提示长度的文档列表。
返回:
如果该方法不依赖于提示长度,则返回None,否则返回提示的长度(以标记为单位)。
"""
return None
[docs] @abstractmethod
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""将文档合并成一个字符串。
参数:
docs: List[Document],要合并的文档
**kwargs: 用于合并文档的其他参数,通常是提示的其他输入。
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
[docs] @abstractmethod
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""将文档合并成一个字符串。
参数:
docs: List[Document],要合并的文档
**kwargs: 用于合并文档的其他参数,通常是提示的其他输入。
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""准备输入,调用合并文档,准备输出。"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
async def _acall(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""准备输入,调用合并文档,准备输出。"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = await self.acombine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
[docs]class AnalyzeDocumentChain(Chain):
"""链条,用于拆分文档,然后逐段分析。
该链条由TextSplitter和CombineDocumentsChain参数化。
该链条以单个文档作为输入,然后将其拆分为块,然后将这些块传递给CombineDocumentsChain。"""
input_key: str = "input_document" #: :meta private:
text_splitter: TextSplitter = Field(default_factory=RecursiveCharacterTextSplitter)
combine_docs_chain: BaseCombineDocumentsChain
@property
def input_keys(self) -> List[str]:
"""期望输入键。
:元数据 私有:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""返回输出关键字。
:元数据 私有:
"""
return self.combine_docs_chain.output_keys
[docs] def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config)
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""将文档分割成块,并传递给CombineDocumentsChain。"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs
return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
)