Source code for langchain.chains.retrieval_qa.base

"""Chain for question-answering against a vector database."""
from __future__ import annotations

import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from langchain_core._api import deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR


class BaseRetrievalQA(Chain):
    """Base class for question-answering chains.

    Subclasses supply the document lookup (``_get_docs`` / ``_aget_docs``);
    this class feeds the retrieved documents and the question to
    ``combine_documents_chain`` and packages the answer.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain used to combine the retrieved documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether to return the source documents alongside the answer."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        keys = [self.output_key]
        if self.return_source_documents:
            keys.append("source_documents")
        return keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        llm_chain_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Initialize from an LLM.

        Builds an ``LLMChain`` around ``llm`` (using ``prompt`` or, when it is
        falsy, the prompt chosen by ``PROMPT_SELECTOR``), wraps it in a
        ``StuffDocumentsChain`` that fills the ``context`` variable, and
        passes the result to the class constructor together with ``kwargs``.
        The same ``callbacks`` value is handed to every object created here.
        """
        chosen_prompt = prompt if prompt else PROMPT_SELECTOR.get_prompt(llm)
        extra_llm_chain_kwargs = llm_chain_kwargs or {}
        qa_llm_chain = LLMChain(
            llm=llm,
            prompt=chosen_prompt,
            callbacks=callbacks,
            **extra_llm_chain_kwargs,
        )
        # Each retrieved document is rendered with this template before the
        # documents are stuffed into the prompt.
        per_document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="Context:\n{page_content}",
        )
        stuff_chain = StuffDocumentsChain(
            llm_chain=qa_llm_chain,
            document_variable_name="context",
            document_prompt=per_document_prompt,
            callbacks=callbacks,
        )
        return cls(
            combine_documents_chain=stuff_chain,
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Load the chain from a chain type.

        The combine-documents chain is obtained from ``load_qa_chain`` with
        ``chain_type`` and the optional ``chain_type_kwargs``; ``kwargs`` are
        forwarded to the class constructor.
        """
        loader_kwargs = chain_type_kwargs or {}
        qa_chain = load_qa_chain(llm, chain_type=chain_type, **loader_kwargs)
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get the documents to do question answering over."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Fetch documents for the input query and run the combine chain.

        If the chain has 'return_source_documents' set to 'True', the
        retrieved documents are also returned under the key
        'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        # Only pass ``run_manager`` when the subclass' ``_get_docs`` declares
        # that parameter; otherwise call it with the question alone.
        fetch_params = inspect.signature(self._get_docs).parameters
        if "run_manager" in fetch_params:
            docs = self._get_docs(question, run_manager=manager)
        else:
            docs = self._get_docs(question)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs,
            question=question,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get the documents to do question answering over."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously fetch documents for the query and run the combine chain.

        If the chain has 'return_source_documents' set to 'True', the
        retrieved documents are also returned under the key
        'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        # Only pass ``run_manager`` when the subclass' ``_aget_docs`` declares
        # that parameter; otherwise call it with the question alone.
        fetch_params = inspect.signature(self._aget_docs).parameters
        if "run_manager" in fetch_params:
            docs = await self._aget_docs(question, run_manager=manager)
        else:
            docs = await self._aget_docs(question)  # type: ignore[call-arg]

        answer = await self.combine_documents_chain.arun(
            input_documents=docs,
            question=question,
            callbacks=manager.get_child(),
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = docs
        return result
@deprecated(since="0.1.17", alternative="create_retrieval_chain", removal="0.3.0")
class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.

    This class is deprecated. See below for an example implementation using
    `create_retrieval_chain`:

        .. code-block:: python

            from langchain.chains import create_retrieval_chain
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import ChatOpenAI


            retriever = ...  # Your retriever
            llm = ChatOpenAI()

            system_prompt = (
                "Use the given context to answer the question. "
                "If you don't know the answer, say you don't know. "
                "Use three sentence maximum and keep the answer concise. "
                "Context: {context}"
            )
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("human", "{input}"),
                ]
            )
            question_answer_chain = create_stuff_documents_chain(llm, prompt)
            chain = create_retrieval_chain(retriever, question_answer_chain)

            chain.invoke({"input": query})

    Example:
        .. code-block:: python

            from langchain_community.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain_community.vectorstores import FAISS
            from langchain_core.vectorstores import VectorStoreRetriever
            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs by invoking the retriever with child callbacks."""
        invoke_config = {"callbacks": run_manager.get_child()}
        return self.retriever.invoke(question, config=invoke_config)

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs by asynchronously invoking the retriever with child callbacks."""
        invoke_config = {"callbacks": run_manager.get_child()}
        return await self.retriever.ainvoke(question, config=invoke_config)

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa"
class VectorDBQA(BaseRetrievalQA):
    """Chain for question-answering against a vector database."""

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over the vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search arguments."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Emit a warning that `VectorDBQA` is deprecated; values pass through unchanged."""
        message = (
            "`VectorDBQA` is deprecated - "
            "please use `from langchain.chains import RetrievalQA`"
        )
        warnings.warn(message)
        return values

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate the search type.

        Raises:
            ValueError: If ``search_type`` is present and is neither
                ``similarity`` nor ``mmr``.
        """
        if "search_type" not in values:
            return values
        search_type = values["search_type"]
        if search_type not in ("similarity", "mmr"):
            raise ValueError(f"search_type of {search_type} not allowed.")
        return values

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs from the vector store using the configured search type.

        Raises:
            ValueError: If ``search_type`` is neither ``similarity`` nor ``mmr``.
        """
        if self.search_type == "similarity":
            return self.vectorstore.similarity_search(
                question, k=self.k, **self.search_kwargs
            )
        if self.search_type == "mmr":
            return self.vectorstore.max_marginal_relevance_search(
                question, k=self.k, **self.search_kwargs
            )
        raise ValueError(f"search_type of {self.search_type} not allowed.")

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs (async is not supported; always raises NotImplementedError)."""
        raise NotImplementedError("VectorDBQA does not support async")

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"