"""用于针对向量数据库进行问答的链。"""
from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
class BaseRetrievalQA(Chain):
    """Base class for question-answering chains.

    Subclasses supply the retrieval step by implementing ``_get_docs`` and
    ``_aget_docs``; this class feeds the retrieved documents, together with
    the question, to ``combine_documents_chain`` to produce the answer.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain used to combine the retrieved documents into an answer."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether to also return the retrieved source documents."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        keys = [self.output_key]
        if self.return_source_documents:
            keys.append("source_documents")
        return keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        callbacks: Callbacks = None,
        llm_chain_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Initialize the chain from an LLM.

        Args:
            llm: Language model that answers the question.
            prompt: Prompt for the answering step. When falsy, the prompt is
                obtained from ``PROMPT_SELECTOR.get_prompt(llm)``.
            callbacks: Callbacks handed to the ``LLMChain``, to the
                ``StuffDocumentsChain`` and to the returned chain.
            llm_chain_kwargs: Extra keyword arguments for ``LLMChain``.
            **kwargs: Extra keyword arguments for the class constructor.

        Returns:
            An instance of ``cls`` whose ``combine_documents_chain`` is a
            ``StuffDocumentsChain`` that exposes the documents to the prompt
            under the ``context`` variable.
        """
        chosen_prompt = prompt if prompt else PROMPT_SELECTOR.get_prompt(llm)
        extra_llm_kwargs = llm_chain_kwargs or {}
        answer_chain = LLMChain(
            llm=llm,
            prompt=chosen_prompt,
            callbacks=callbacks,
            **extra_llm_kwargs,
        )
        # Template applied to each individual document before stuffing.
        per_document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="Context:\n{page_content}",
        )
        stuff_chain = StuffDocumentsChain(
            llm_chain=answer_chain,
            document_variable_name="context",
            document_prompt=per_document_prompt,
            callbacks=callbacks,
        )
        return cls(
            combine_documents_chain=stuff_chain,
            callbacks=callbacks,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Load the chain from a chain type.

        Args:
            llm: Language model passed to ``load_qa_chain``.
            chain_type: Name of the combine-documents chain type to load.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
            **kwargs: Extra keyword arguments for the class constructor.

        Returns:
            An instance of ``cls`` wrapping the loaded chain.
        """
        loader_kwargs = chain_type_kwargs or {}
        qa_chain = load_qa_chain(llm, chain_type=chain_type, **loader_kwargs)
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @abstractmethod
    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get the documents to do question answering over."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and combine them into an answer.

        The answer is returned under ``output_key``. If the chain has
        ``return_source_documents`` set to ``True``, the retrieved documents
        are returned as well under the key ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # ``run_manager`` is only forwarded when the concrete ``_get_docs``
        # declares it, so overrides without that parameter keep working.
        retrieval_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            retrieval_kwargs["run_manager"] = manager
        documents = self._get_docs(query, **retrieval_kwargs)

        answer = self.combine_documents_chain.run(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = documents
        return output

    @abstractmethod
    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Asynchronously get the documents to do question answering over."""

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve documents and combine them into an answer.

        The answer is returned under ``output_key``. If the chain has
        ``return_source_documents`` set to ``True``, the retrieved documents
        are returned as well under the key ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        query = inputs[self.input_key]

        # ``run_manager`` is only forwarded when the concrete ``_aget_docs``
        # declares it, so overrides without that parameter keep working.
        retrieval_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            retrieval_kwargs["run_manager"] = manager
        documents = await self._aget_docs(query, **retrieval_kwargs)

        answer = await self.combine_documents_chain.arun(
            input_documents=documents,
            question=query,
            callbacks=manager.get_child(),
        )

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = documents
        return output
@deprecated(since="0.1.17", alternative="create_retrieval_chain", removal="0.3.0")
class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.

    This class is deprecated. See below for an example implementation using
    `create_retrieval_chain`:

        .. code-block:: python

            from langchain.chains import create_retrieval_chain
            from langchain.chains.combine_documents import create_stuff_documents_chain
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_openai import ChatOpenAI


            retriever = ...  # Your retriever
            llm = ChatOpenAI()

            system_prompt = (
                "Use the given context to answer the question. "
                "If you don't know the answer, say you don't know. "
                "Use three sentence maximum and keep the answer concise. "
                "Context: {context}"
            )
            prompt = ChatPromptTemplate.from_messages(
                [
                    ("system", system_prompt),
                    ("human", "{input}"),
                ]
            )
            question_answer_chain = create_stuff_documents_chain(llm, prompt)
            chain = create_retrieval_chain(retriever, question_answer_chain)

            chain.invoke({"input": query})

    Example:
        .. code-block:: python

            from langchain_community.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain_community.vectorstores import FAISS
            from langchain_core.vectorstores import VectorStoreRetriever

            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    # Retriever queried for documents; excluded from the model's export.
    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs by invoking the retriever with the question.

        A child of ``run_manager`` is passed to the retriever as its
        ``callbacks`` config entry.
        """
        child_callbacks = run_manager.get_child()
        retriever_config = {"callbacks": child_callbacks}
        return self.retriever.invoke(question, config=retriever_config)

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Asynchronously get docs by invoking the retriever with the question.

        A child of ``run_manager`` is passed to the retriever as its
        ``callbacks`` config entry.
        """
        child_callbacks = run_manager.get_child()
        retriever_config = {"callbacks": child_callbacks}
        return await self.retriever.ainvoke(question, config=retriever_config)

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "retrieval_qa"
class VectorDBQA(BaseRetrievalQA):
    """Chain for question-answering against a vector database.

    A warning recommending ``RetrievalQA`` instead is emitted by a root
    validator as part of model validation.
    """

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over the vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search arguments."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn that this class is deprecated; the values pass through as-is."""
        message = (
            "`VectorDBQA` is deprecated - "
            "please use `from langchain.chains import RetrievalQA`"
        )
        warnings.warn(message)
        return values

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate the search type.

        Raises:
            ValueError: If ``search_type`` is present in ``values`` and is
                neither ``"similarity"`` nor ``"mmr"``.
        """
        if "search_type" not in values:
            return values
        allowed_types = ("similarity", "mmr")
        requested = values["search_type"]
        if requested not in allowed_types:
            raise ValueError(f"search_type of {requested} not allowed.")
        return values

    def _get_docs(
        self,
        question: str,
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs from the vectorstore using the configured search type.

        ``run_manager`` is accepted for interface compatibility and is not
        used by this implementation.

        Raises:
            ValueError: If ``search_type`` is neither ``"similarity"`` nor
                ``"mmr"``.
        """
        # Pick the search method first, then issue a single call with the
        # arguments shared by both search types.
        if self.search_type == "similarity":
            search = self.vectorstore.similarity_search
        elif self.search_type == "mmr":
            search = self.vectorstore.max_marginal_relevance_search
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return search(question, k=self.k, **self.search_kwargs)

    async def _aget_docs(
        self,
        question: str,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Async retrieval is unsupported; always raises.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError("VectorDBQA does not support async")

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"