Source code for langchain.chains.conversational_retrieval.base

"""用于与向量数据库进行聊天的链。"""
from __future__ import annotations

import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStore

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain

# Depending on the memory type and configuration, the chat history format may differ.
# This needs to be consolidated.
CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
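# Illustrative note (not part of the original module): a chat turn may be either a
# plain (human, ai) string tuple or a BaseMessage, e.g.
#
#     ("What is LangChain?", "A framework for building LLM applications.")
#     HumanMessage(content="What is LangChain?")
#
# Both forms are accepted by _get_chat_history below; the example values are
# hypothetical.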


_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}


def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
    buffer = ""
    for dialogue_turn in chat_history:
        if isinstance(dialogue_turn, BaseMessage):
            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
            buffer += f"\n{role_prefix}{dialogue_turn.content}"
        elif isinstance(dialogue_turn, tuple):
            human = "Human: " + dialogue_turn[0]
            ai = "Assistant: " + dialogue_turn[1]
            buffer += "\n" + "\n".join([human, ai])
        else:
            raise ValueError(
                f"Unsupported chat history format: {type(dialogue_turn)}."
                f" Full chat history: {chat_history} "
            )
    return buffer
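
# A minimal sketch (not part of the original module) of what the default formatter
# above produces for a tuple-style history; the strings are hypothetical:
#
#     _get_chat_history([("What is LangChain?", "A framework for LLM apps.")])
#     # -> "\nHuman: What is LangChain?\nAssistant: A framework for LLM apps."
#
# BaseMessage turns are prefixed via _ROLE_MAP ("Human: " / "Assistant: "),
# falling back to "<type>: " for any other message type.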


class InputType(BaseModel):
    """Input type for ConversationalRetrievalChain."""

    question: str
    """The question to answer."""
    chat_history: List[CHAT_TURN_TYPE] = Field(default_factory=list)
    """The chat history to use for retrieval."""

class BaseConversationalRetrievalChain(Chain):
    """Chain for chatting with an index."""

    combine_docs_chain: BaseCombineDocumentsChain
    """The chain used to combine any retrieved documents."""
    question_generator: LLMChain
    """The chain used to generate a new question for the sake of retrieval.
    This chain will take in the current question (with variable `question`)
    and any chat history (with variable `chat_history`) and will produce
    a new standalone question to be used later on."""
    output_key: str = "answer"
    """The output key to return the final answer of this chain in."""
    rephrase_question: bool = True
    """Whether or not to pass the newly generated question to the
    combine_docs_chain. If True, the newly generated question is passed along.
    If False, the newly generated question is only used for retrieval and the
    original question is passed along to the combine_docs_chain."""
    return_source_documents: bool = False
    """Return the retrieved source documents as part of the final result."""
    return_generated_question: bool = False
    """Return the generated question as part of the final result."""
    get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
    """An optional function to get a string of the chat history.
    If None is provided, a default formatter is used."""
    response_if_no_docs_found: Optional[str]
    """If specified, the chain will return this fixed response if no documents
    are found for the question."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys."""
        return ["question", "chat_history"]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        return InputType

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        if self.return_generated_question:
            _output_keys = _output_keys + ["generated_question"]
        return _output_keys

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = self.combine_docs_chain.run(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer

        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = await self.question_generator.arun(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = await self._aget_docs(new_question, inputs)  # type: ignore[call-arg]

        output: Dict[str, Any] = {}
        if self.response_if_no_docs_found is not None and len(docs) == 0:
            output[self.output_key] = self.response_if_no_docs_found
        else:
            new_inputs = inputs.copy()
            if self.rephrase_question:
                new_inputs["question"] = new_question
            new_inputs["chat_history"] = chat_history_str
            answer = await self.combine_docs_chain.arun(
                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
            )
            output[self.output_key] = answer
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output

    def save(self, file_path: Union[Path, str]) -> None:
        if self.get_chat_history:
            raise ValueError("Chain not saveable when `get_chat_history` is not None.")
        super().save(file_path)

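# Illustrative note (not part of the original module): the result dict grows with the
# return_* flags declared above. For a hypothetical concrete subclass instance `chain`
# constructed with return_source_documents=True and return_generated_question=True:
#
#     chain.output_keys  # -> ["answer", "source_documents", "generated_question"]
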
@deprecated(
    since="0.1.17",
    alternative=(
        "create_history_aware_retriever together with create_retrieval_chain "
        "(see example in docstring)"
    ),
    removal="0.3.0",
)
class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
    """Chain for having a conversation based on retrieved documents.

    This class is deprecated. See below for an example implementation using
    `create_retrieval_chain`. Additional walkthroughs can be found at
    https://python.langchain.com/docs/use_cases/question_answering/chat_history

        .. code-block:: python

            from langchain.chains import (
                create_history_aware_retriever,
                create_retrieval_chain,
            )
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
            from langchain_openai import ChatOpenAI

            retriever = ...  # Your retriever

            llm = ChatOpenAI()

            # Contextualize the question
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question "
                "which might reference context in the chat history, "
                "formulate a standalone question which can be understood "
                "without the chat history. Reformulate the question if "
                "needed and otherwise return it as is."
            )
            contextualize_q_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", contextualize_q_system_prompt),
                    MessagesPlaceholder("chat_history"),
                    ("human", "{input}"),
                ]
            )
            history_aware_retriever = create_history_aware_retriever(
                llm, retriever, contextualize_q_prompt
            )

            # Answer the question
            qa_system_prompt = (
                "You are an assistant for question-answering tasks. Use "
                "the following pieces of retrieved context to answer the "
                "question. If you don't know the answer, just say that you "
                "don't know. Use three sentences maximum and keep the answer "
                "concise."
                "\n\n"
                "{context}"
            )
            qa_prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", qa_system_prompt),
                    MessagesPlaceholder("chat_history"),
                    ("human", "{input}"),
                ]
            )
            # Below we use create_stuff_documents_chain to feed all retrieved context
            # into the LLM. Note that we can also use StuffDocumentsChain and other
            # instances of BaseCombineDocumentsChain.
            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            rag_chain = create_retrieval_chain(
                history_aware_retriever, question_answer_chain
            )

            # Usage:
            chat_history = []  # Collect chat history here (a sequence of messages)
            rag_chain.invoke({"input": query, "chat_history": chat_history})

    This chain takes in chat history (a list of messages) and new questions,
    and then returns an answer to that question.
    The algorithm for this chain consists of three parts:

    1. Use the chat history and the new question to create a "standalone question".
    This is done so that this question can be passed into the retrieval step to fetch
    relevant documents. If only the new question was passed in, then relevant context
    may be lacking. If the whole conversation was passed into retrieval, there may be
    unnecessary information there that would distract from retrieval.

    2. This new question is passed to the retriever, and relevant documents are
    returned.

    3. The retrieved documents are passed to an LLM along with either the new question
    (default behavior) or the original question and chat history to generate a final
    response.

    Example:
        .. code-block:: python

            from langchain.chains import (
                StuffDocumentsChain, LLMChain, ConversationalRetrievalChain
            )
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            combine_docs_chain = StuffDocumentsChain(...)
            vectorstore = ...
            retriever = vectorstore.as_retriever()

            # This controls how the standalone question is generated.
            # Should take `chat_history` and `question` as input variables.
            template = (
                "Combine the chat history and follow up question into "
                "a standalone question. Chat History: {chat_history}"
                "Follow up question: {question}"
            )
            prompt = PromptTemplate.from_template(template)
            llm = OpenAI()
            question_generator_chain = LLMChain(llm=llm, prompt=prompt)
            chain = ConversationalRetrievalChain(
                combine_docs_chain=combine_docs_chain,
                retriever=retriever,
                question_generator=question_generator_chain,
            )
    """

    retriever: BaseRetriever
    """Retriever to use to fetch documents."""
    max_tokens_limit: Optional[int] = None
    """If set, enforces that the documents returned are less than this limit.
    This is only enforced if `combine_docs_chain` is of type StuffDocumentsChain."""

    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        num_docs = len(docs)

        if self.max_tokens_limit and isinstance(
            self.combine_docs_chain, StuffDocumentsChain
        ):
            tokens = [
                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                for doc in docs
            ]
            token_count = sum(tokens[:num_docs])
            while token_count > self.max_tokens_limit:
                num_docs -= 1
                token_count -= tokens[num_docs]

        return docs[:num_docs]

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        docs = self.retriever.invoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        docs = await self.retriever.ainvoke(
            question, config={"callbacks": run_manager.get_child()}
        )
        return self._reduce_tokens_below_limit(docs)

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        retriever: BaseRetriever,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        verbose: bool = False,
        condense_question_llm: Optional[BaseLanguageModel] = None,
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Convenience method to load a chain from an LLM and a retriever.

        This provides some logic to create the `question_generator` chain
        as well as the `combine_docs_chain`.

        Args:
            llm: The default language model to use at every part of this chain
                (e.g. in both the question generation and in the answering).
            retriever: The retriever to use to fetch relevant documents from.
            condense_question_prompt: The prompt to use to condense the chat history
                and new question into a standalone question.
            chain_type: The chain type to use to create the `combine_docs_chain`,
                will be sent to `load_qa_chain`.
            verbose: Verbosity flag for logging to stdout.
            condense_question_llm: The language model to use for condensing the chat
                history and new question into a standalone question. If none is
                provided, will default to `llm`.
            combine_docs_chain_kwargs: Parameters to pass as kwargs to `load_qa_chain`
                when constructing the `combine_docs_chain`.
            callbacks: Callbacks to pass to all subchains.
            **kwargs: Additional parameters to pass when initializing
                ConversationalRetrievalChain.
        """
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            verbose=verbose,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )

        _llm = condense_question_llm or llm
        condense_question_chain = LLMChain(
            llm=_llm,
            prompt=condense_question_prompt,
            verbose=verbose,
            callbacks=callbacks,
        )
        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )

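# A minimal usage sketch (assumptions: an existing retriever and an OpenAI chat model
# are available; `my_retriever` and the question text are hypothetical):
#
#     from langchain_openai import ChatOpenAI
#
#     chain = ConversationalRetrievalChain.from_llm(
#         llm=ChatOpenAI(),
#         retriever=my_retriever,
#         return_source_documents=True,
#     )
#     result = chain.invoke(
#         {"question": "What did the report conclude?", "chat_history": []}
#     )
#     result["answer"], result["source_documents"]
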
class ChatVectorDBChain(BaseConversationalRetrievalChain):
    """Chain for chatting with a vector database."""

    vectorstore: VectorStore = Field(alias="vectorstore")
    top_k_docs_for_context: int = 4
    search_kwargs: dict = Field(default_factory=dict)

    @property
    def _chain_type(self) -> str:
        return "chat-vector-db"

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        warnings.warn(
            "`ChatVectorDBChain` is deprecated - "
            "please use `from langchain.chains import ConversationalRetrievalChain`"
        )
        return values

    def _get_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        vectordbkwargs = inputs.get("vectordbkwargs", {})
        full_kwargs = {**self.search_kwargs, **vectordbkwargs}
        return self.vectorstore.similarity_search(
            question, k=self.top_k_docs_for_context, **full_kwargs
        )

    async def _aget_docs(
        self,
        question: str,
        inputs: Dict[str, Any],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs."""
        raise NotImplementedError("ChatVectorDBChain does not support async")

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstore: VectorStore,
        condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
        chain_type: str = "stuff",
        combine_docs_chain_kwargs: Optional[Dict] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseConversationalRetrievalChain:
        """Load chain from an LLM."""
        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
            llm,
            chain_type=chain_type,
            callbacks=callbacks,
            **combine_docs_chain_kwargs,
        )
        condense_question_chain = LLMChain(
            llm=llm, prompt=condense_question_prompt, callbacks=callbacks
        )
        return cls(
            vectorstore=vectorstore,
            combine_docs_chain=doc_chain,
            question_generator=condense_question_chain,
            callbacks=callbacks,
            **kwargs,
        )
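
# A minimal usage sketch for the deprecated ChatVectorDBChain (assumptions: an existing
# vector store `my_vectorstore` and language model `llm`; both names are hypothetical).
# New code should prefer ConversationalRetrievalChain or create_retrieval_chain:
#
#     chain = ChatVectorDBChain.from_llm(llm, my_vectorstore, top_k_docs_for_context=2)
#     chain.invoke({"question": "What did the report conclude?", "chat_history": []})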