"""用于与向量数据库进行聊天的链。"""
from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStore
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
# A single chat turn is either a 2-tuple of strings (read by `_get_chat_history`
# as (human text, assistant text)) or a `BaseMessage`.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
# Maps a message's `type` attribute to the speaker prefix used when a
# `BaseMessage` turn is rendered as text by `_get_chat_history`.
_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
    """Render a chat history as one newline-delimited transcript string.

    A ``BaseMessage`` turn becomes a single line prefixed with its role
    (looked up in ``_ROLE_MAP`` by ``type``, falling back to ``"<type>: "``).
    A tuple turn becomes two lines: ``Human: <item 0>`` followed by
    ``Assistant: <item 1>``. Every line is preceded by a newline, so any
    non-empty result starts with a newline character.

    Args:
        chat_history: The turns to render, in order.

    Returns:
        The rendered transcript, or an empty string when there are no turns.

    Raises:
        ValueError: If a turn is neither a ``BaseMessage`` nor a tuple.
    """
    rendered_lines: List[str] = []
    for turn in chat_history:
        if isinstance(turn, BaseMessage):
            prefix = _ROLE_MAP.get(turn.type, f"{turn.type}: ")
            rendered_lines.append(f"{prefix}{turn.content}")
            continue
        if isinstance(turn, tuple):
            rendered_lines.append("Human: " + turn[0])
            rendered_lines.append("Assistant: " + turn[1])
            continue
        raise ValueError(
            f"Unsupported chat history format: {type(turn)}."
            f" Full chat history: {chat_history} "
        )
    # Each rendered line carries its own leading newline, as in the transcript
    # format expected by the condense-question prompt.
    return "".join(f"\n{line}" for line in rendered_lines)
class BaseConversationalRetrievalChain(Chain):
    """Chain for chatting with an index."""

    combine_docs_chain: BaseCombineDocumentsChain
    """The chain used to combine any retrieved documents."""
    question_generator: LLMChain
    """The chain used to generate a new question for the sake of retrieval.
    This chain takes in the current question (with variable `question`)
    and any chat history (with variable `chat_history`) and produces
    a new standalone question to be used later on."""
    output_key: str = "answer"
    """The output key under which the final answer of this chain is returned."""
    rephrase_question: bool = True
    """Whether or not to pass the newly generated question to combine_docs_chain.
    If True, the newly generated question is passed on.
    If False, the newly generated question is used for retrieval only and the
    original question is passed along to combine_docs_chain."""
    return_source_documents: bool = False
    """Return the retrieved source documents as part of the final result."""
    return_generated_question: bool = False
    """Return the generated question as part of the final result."""
    get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
    """An optional function that turns the chat history into a string.
    If None is provided, a default is used."""
    # Explicit default: an ``Optional`` pydantic v1 field without one is
    # implicitly ``None``, which is easy to misread as "required".
    response_if_no_docs_found: Optional[str] = None
    """If specified, the chain returns this fixed response when no documents
    are found for the question."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys."""
        return ["question", "chat_history"]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        if self.return_generated_question:
            _output_keys = _output_keys + ["generated_question"]
        return _output_keys

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain synchronously.

        Condenses the question with the chat history (when there is any),
        retrieves documents for the resulting question, and combines them into
        an answer.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict holding the answer under ``output_key``, plus
            ``"source_documents"`` and/or ``"generated_question"`` when the
            corresponding ``return_*`` flags are set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            # Nothing to condense against: use the question as-is.
            new_question = question

        # Subclasses may override `_get_docs` without the `run_manager`
        # keyword; only pass it when the override accepts it.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and not docs:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = self.combine_docs_chain.run(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer

        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain asynchronously.

        Mirrors ``_call`` step for step, using the async variants of the
        question generator, document retrieval and combine-documents chain.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict holding the answer under ``output_key``, plus
            ``"source_documents"`` and/or ``"generated_question"`` when the
            corresponding ``return_*`` flags are set.
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = await self.question_generator.arun(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            # Nothing to condense against: use the question as-is.
            new_question = question

        # Subclasses may override `_aget_docs` without the `run_manager`
        # keyword; only pass it when the override accepts it.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and not docs:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = await self.combine_docs_chain.arun(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer

        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the chain to ``file_path``.

        Raises:
            ValueError: If ``get_chat_history`` is set; the chain is not
                saveable in that case.
        """
        if self.get_chat_history:
            raise ValueError("Chain not saveable when `get_chat_history` is not None.")
        super().save(file_path)
@deprecated(
    since="0.1.17",
    alternative=(
        "create_history_aware_retriever together with create_retrieval_chain "
        "(see example in docstring)"
    ),
    removal="0.3.0",
)
class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
    """Chain for having a conversation based on retrieved documents.

    This class is deprecated. See below for an example implementation using
    `create_retrieval_chain`. More examples can be found at
    https://python.langchain.com/docs/use_cases/question_answering/chat_history

        .. code-block:: python

            from langchain.chains import (
                create_history_aware_retriever,
                create_retrieval_chain,
            )
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
            from langchain_openai import ChatOpenAI

            retriever = ...  # Your retriever

            llm = ChatOpenAI()

            # Contextualize question
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question, "
                "which might reference context in the chat history, "
                "formulate a standalone question which can be understood "
                "without the chat history. Reformulate the question if "
                "needed and otherwise return it as is."
            )
            contextualize_q_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", contextualize_q_system_prompt),
                    MessagesPlaceholder("chat_history"),
                    ("human", "{input}"),
                ]
            )
            history_aware_retriever = create_history_aware_retriever(
                llm, retriever, contextualize_q_prompt
            )

            # Answer question
            qa_system_prompt = (
                "You are an assistant for question-answering tasks. Use "
                "the following pieces of retrieved context to answer the "
                "question. If you don't know the answer, just say that you "
                "don't know. Use three sentences maximum and keep the "
                "answer concise."
                "\\n\\n"
                "{context}"
            )
            qa_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", qa_system_prompt),
                    MessagesPlaceholder("chat_history"),
                    ("human", "{input}"),
                ]
            )
            # Below we use create_stuff_documents_chain to feed all retrieved
            # context into the LLM. Note that we can also use
            # StuffDocumentsChain and other instances of
            # BaseCombineDocumentsChain.
            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            rag_chain = create_retrieval_chain(
                history_aware_retriever, question_answer_chain
            )

            # Usage:
            chat_history = []  # Collect chat history here (a sequence of messages)
            rag_chain.invoke({"input": query, "chat_history": chat_history})

    This chain takes in chat history (a list of messages) and a new question,
    and then returns an answer to that question.
    The algorithm for this chain consists of three parts:

    1. Use the chat history and the new question to create a "standalone
    question". This is done so that this question can be passed into the
    retrieval step to fetch relevant documents. If only the new question was
    passed in, then relevant context may be lacking. If the whole conversation
    was passed into retrieval, there may be unnecessary information there that
    would distract from retrieval.

    2. This new question is passed to the retriever and relevant documents are
    returned.

    3. The retrieved documents are passed to an LLM along with either the new
    question (default behavior) or the original question and chat history to
    generate a final response.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain, LLMChain, ConversationalRetrievalChain
            )
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            combine_docs_chain = StuffDocumentsChain(...)
            vectorstore = ...
            retriever = vectorstore.as_retriever()

            # This controls how the standalone question is generated.
            # Should take `chat_history` and `question` as input variables.
            template = (
                "Combine the chat history and follow up question into "
                "a standalone question. Chat History: {chat_history}"
                "Follow up question: {question}"
            )
            prompt = PromptTemplate.from_template(template)
            llm = OpenAI()
            question_generator_chain = LLMChain(llm=llm, prompt=prompt)
            chain = ConversationalRetrievalChain(
                combine_docs_chain=combine_docs_chain,
                retriever=retriever,
                question_generator=question_generator_chain,
            )
    """

    retriever: BaseRetriever
    """Retriever used to fetch documents."""
    max_tokens_limit: Optional[int] = None
    """If set, trailing documents are dropped until the total token count of
    the returned documents no longer exceeds this limit. This is only enforced
    if `combine_docs_chain` is of type StuffDocumentsChain."""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        """Trim trailing documents so their total token count fits the limit.

        Does nothing unless ``max_tokens_limit`` is set (truthy) and
        ``combine_docs_chain`` is a ``StuffDocumentsChain``.

        Args:
            docs: Retrieved documents, in the order returned by the retriever.

        Returns:
            A prefix of ``docs`` (possibly empty, possibly all of it).
        """
        num_docs = len(docs)
        if self.max_tokens_limit and isinstance(
            self.combine_docs_chain, StuffDocumentsChain
        ):
            tokens = [
                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                for doc in docs
            ]
            token_count = sum(tokens)
            # Drop documents from the end until the total fits. The
            # `num_docs > 0` guard stops the loop once nothing is left, so a
            # negative limit cannot walk the index past the start of `tokens`.
            while num_docs > 0 and token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]
        return docs[:num_docs]

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Invokes the retriever with ``question`` (``inputs`` is not used here)
        and applies the token limit to the result.
        """
        docs = self.retriever.invoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Async counterpart of ``_get_docs``: awaits the retriever with
        ``question`` and applies the token limit to the result.
        """
        docs = await self.retriever.ainvoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        retriever: BaseRetriever,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        verbose: bool = False,
        condense_question_llm: Optional[BaseLanguageModel] = None,
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Convenience method to load the chain from an LLM and a retriever.

        This provides some logic to create the `question_generator` chain
        as well as the `combine_docs_chain`.

        Args:
            llm: The default language model to use at every part of this chain
                (e.g. in both the question generation and the answering).
            retriever: The retriever to use to fetch relevant documents.
            condense_question_prompt: The prompt to use to condense the chat
                history and new question into a standalone question.
            chain_type: The chain type to use to create the
                `combine_docs_chain`; it is sent to `load_qa_chain`.
            verbose: Verbosity flag for logging to stdout.
            condense_question_llm: The language model to use for condensing
                the chat history and new question into a standalone question.
                If none is provided, defaults to `llm`.
            combine_docs_chain_kwargs: Parameters to pass as kwargs to
                `load_qa_chain` when constructing the `combine_docs_chain`.
            callbacks: Callbacks to pass to all subchains.
            **kwargs: Additional parameters to pass when initializing
                ConversationalRetrievalChain.
        """
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            verbose=verbose,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )

        _llm = condense_question_llm or llm
        condense_question_chain = LLMChain(
            llm=_llm,
            prompt=condense_question_prompt,
            verbose=verbose,
            callbacks=callbacks,
        )
        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )
class ChatVectorDBChain(BaseConversationalRetrievalChain):
    """Chain for chatting with a vector database."""

    vectorstore: VectorStore = Field(alias="vectorstore")
    # `k` passed to the vector store's similarity search.
    top_k_docs_for_context: int = 4
    # Extra keyword arguments forwarded to every similarity search.
    search_kwargs: dict = Field(default_factory=dict)

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "chat-vector-db"

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Emit a deprecation warning during validation; values pass through."""
        warnings.warn(
            "`ChatVectorDBChain` is deprecated - "
            "please use `from langchain.chains import ConversationalRetrievalChain`"
        )
        return values

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Runs a similarity search for ``question`` on the vector store. Search
        keyword arguments come from ``search_kwargs``, overridden per call by
        the optional ``inputs["vectordbkwargs"]`` mapping.
        """
        # `or {}` also covers an explicit `vectordbkwargs=None`, which would
        # otherwise fail when unpacked with `**`.
        vectordbkwargs = inputs.get("vectordbkwargs") or {}
        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
        return self.vectorstore.similarity_search(
            question, k=self.top_k_docs_for_context, **full_kwargs
        )

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Raises:
            NotImplementedError: Always; this chain has no async retrieval.
        """
        raise NotImplementedError("ChatVectorDBChain does not support async")

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load the chain from an LLM.

        Args:
            llm: Language model used for both question condensing and
                answering.
            vectorstore: Vector store searched for relevant documents.
            condense_question_prompt: Prompt for the question-condensing
                chain.
            chain_type: Chain type passed to `load_qa_chain`.
            combine_docs_chain_kwargs: Extra kwargs passed to `load_qa_chain`.
            callbacks: Callbacks passed to the sub-chains and to this chain.
            **kwargs: Additional parameters passed to the class constructor.
        """
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )
        condense_question_chain = LLMChain(
            llm=llm, prompt=condense_question_prompt, callbacks=callbacks
        )
        return cls(
            vectorstore=vectorstore,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )