"""将文档通过填充到上下文中进行组合的链。"""
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.combine_documents.base import (
DEFAULT_DOCUMENT_PROMPT,
DEFAULT_DOCUMENT_SEPARATOR,
DOCUMENTS_KEY,
BaseCombineDocumentsChain,
_validate_prompt,
)
from langchain.chains.llm import LLMChain
[docs]def create_stuff_documents_chain(
llm: LanguageModelLike,
prompt: BasePromptTemplate,
*,
output_parser: Optional[BaseOutputParser] = None,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = DEFAULT_DOCUMENT_SEPARATOR,
) -> Runnable[Dict[str, Any], Any]:
"""创建一个链条,用于将文档列表传递给模型。
参数:
llm: 语言模型。
prompt: 提示模板。必须包含输入变量“context”,该变量将用于传递格式化后的文档。
output_parser: 输出解析器。默认为StrOutputParser。
document_prompt: 用于将每个文档格式化为字符串的提示。输入变量可以是“page_content”或所有文档中都存在的任何元数据键。 “page_content”将自动检索`Document.page_content`,所有其他输入变量将从`Document.metadata`字典中自动检索。默认为仅包含`Document.page_content`的提示。
document_separator: 用于在格式化的文档字符串之间使用的字符串分隔符。
返回:
一个LCEL Runnable。输入是一个字典,必须有一个映射到List[Document]的“context”键,以及提示中预期的任何其他输入变量。Runnable返回类型取决于使用的output_parser。
示例:
.. code-block:: python
# pip install -U langchain langchain-community
from langchain_community.chat_models import ChatOpenAI
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
prompt = ChatPromptTemplate.from_messages(
[("system", "What are everyone's favorite colors:\n\n{context}")]
)
llm = ChatOpenAI(model="gpt-3.5-turbo")
chain = create_stuff_documents_chain(llm, prompt)
docs = [
Document(page_content="Jesse loves red but not yellow"),
Document(page_content = "Jamal loves green but not as much as he loves orange")
]
chain.invoke({"context": docs})
""" # noqa: E501
_validate_prompt(prompt)
_document_prompt = document_prompt or DEFAULT_DOCUMENT_PROMPT
_output_parser = output_parser or StrOutputParser()
def format_docs(inputs: dict) -> str:
return document_separator.join(
format_document(doc, _document_prompt) for doc in inputs[DOCUMENTS_KEY]
)
return (
RunnablePassthrough.assign(**{DOCUMENTS_KEY: format_docs}).with_config(
run_name="format_inputs"
)
| prompt
| llm
| _output_parser
).with_config(run_name="stuff_documents_chain")
[docs]class StuffDocumentsChain(BaseCombineDocumentsChain):
"""将文档通过填充内容组合的链。
该链接收一个文档列表,首先将它们组合成一个字符串。
它通过使用`document_prompt`将每个文档格式化为一个字符串,然后用`document_separator`将它们连接在一起来实现这一点。然后将该新字符串添加到由`document_variable_name`设置的输入中。
然后将这些输入传递给`llm_chain`。
示例:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_community.llms import OpenAI
# 这控制每个文档将如何被格式化。具体来说,
# 它将被传递给`format_document` - 有关更多详细信息,请参见该函数。
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# 这里的提示应该将`document_variable_name`作为输入变量
prompt = PromptTemplate.from_template(
"总结这段内容:{context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)"""
llm_chain: LLMChain
"""LLM链,通过格式化的文档字符串调用,以及任何其他输入。"""
document_prompt: BasePromptTemplate = Field(
default_factory=lambda: DEFAULT_DOCUMENT_PROMPT
)
"""传递给`format_document`的每个文档的格式提示。"""
document_variable_name: str
"""变量名在llm_chain中放置文档。
如果在llm_chain中只有一个变量,则无需提供。"""
document_separator: str = "\n\n"
"""用于连接格式化文档的字符串"""
class Config:
"""这个pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
"""获取默认文档变量名称,如果未提供。
如果在llm_chain.prompt中只有一个变量,
我们可以推断应该使用这个变量名传递格式化后的文档。
"""
llm_chain_variables = values["llm_chain"].prompt.input_variables
if "document_variable_name" not in values:
if len(llm_chain_variables) == 1:
values["document_variable_name"] = llm_chain_variables[0]
else:
raise ValueError(
"document_variable_name must be provided if there are "
"multiple llm_chain_variables"
)
else:
if values["document_variable_name"] not in llm_chain_variables:
raise ValueError(
f"document_variable_name {values['document_variable_name']} was "
f"not found in llm_chain input_variables: {llm_chain_variables}"
)
return values
@property
def input_keys(self) -> List[str]:
extra_keys = [
k for k in self.llm_chain.input_keys if k != self.document_variable_name
]
return super().input_keys + extra_keys
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""从kwargs和docs构建输入。
将所有文档格式化,然后连接成一个名为`self.document_variable_name`的输入。还可以从**kwargs中提取任何额外的变量。
参数:
docs:要格式化并连接成单个输入的文档列表
**kwargs:要链接的额外输入,将从这里提取任何其他所需的参数。
返回:
输入到LLMChain的字典
"""
# Format each document according to the prompt
doc_strings = [format_document(doc, self.document_prompt) for doc in docs]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in kwargs.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs
[docs] def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
"""给定传入的文档,返回提示长度。
调用者可以使用此功能来确定传入文档列表是否会超过某个提示长度。在尝试确保提示大小保持在某个上下文限制以下时,这很有用。
参数:
docs: List[Document],用于计算总提示长度的文档列表。
返回:
如果该方法不依赖于提示长度,则返回None,否则返回提示的长度(以标记为单位)。
"""
inputs = self._get_inputs(docs, **kwargs)
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain._get_num_tokens(prompt)
[docs] def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""将所有文档内容整合到一个提示中,并传递给LLM。
参数:
docs:需要合并成一个变量的文档列表
callbacks:可选的回调函数以传递
**kwargs:用于获取输入到LLMChain的其他参数
返回值:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
[docs] async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""将所有文档异步地合并到一个提示中,并传递给LLM。
参数:
docs:要合并到一个变量中的文档列表
callbacks:要传递的可选回调函数
**kwargs:用于获取输入到LLMChain的其他参数
返回:
返回的第一个元素是单个字符串输出。返回的第二个元素是要返回的其他键的字典。
"""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"