"""
使用身份和语义强制执行的Pebblo检索链,用于针对向量数据库进行问答。
"""
import inspect
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Extra, Field, validator
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
SUPPORTED_VECTORSTORES,
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
AuthContext,
SemanticContext,
)
class PebbloRetrievalQA(Chain):
    """Retrieval chain with identity and semantic enforcement.

    Answers questions against a vector database: documents are fetched through
    ``retriever`` after enforcement filters have been applied to it, and are
    then handed to ``combine_documents_chain`` together with the question.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain used to combine the retrieved documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether to return the source documents alongside the answer."""
    retriever: VectorStoreRetriever = Field(exclude=True)
    """Vector store retriever used to fetch documents."""
    auth_context_key: str = "auth_context"  #: :meta private:
    """Input key under which the authentication context is looked up."""
    semantic_context_key: str = "semantic_context"  #: :meta private:
    """Input key under which the semantic context is looked up."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents for the query and run the combine-documents chain.

        If the chain has ``return_source_documents`` set to ``True``, the
        retrieved documents are returned under the key ``'source_documents'``.

        Args:
            inputs: Chain inputs. Must contain ``input_key``; the auth and
                semantic context keys are optional and default to ``None``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the answer under ``output_key`` and, optionally, the
            retrieved documents under ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)
        # Only forward ``run_manager`` when ``_get_docs`` declares it
        # (presumably to tolerate overrides with an older signature).
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(
                question, auth_context, semantic_context, run_manager=_run_manager
            )
        else:
            docs = self._get_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]
        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Asynchronously retrieve documents and run the combine-documents chain.

        If the chain has ``return_source_documents`` set to ``True``, the
        retrieved documents are returned under the key ``'source_documents'``.

        Args:
            inputs: Chain inputs. Must contain ``input_key``; the auth and
                semantic context keys are optional and default to ``None``.
            run_manager: Async callback manager for this run; a no-op manager
                is used when omitted.

        Returns:
            A dict with the answer under ``output_key`` and, optionally, the
            retrieved documents under ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)
        # Only forward ``run_manager`` when ``_aget_docs`` declares it
        # (presumably to tolerate overrides with an older signature).
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._aget_docs).parameters
        )
        if accepts_run_manager:
            docs = await self._aget_docs(
                question, auth_context, semantic_context, run_manager=_run_manager
            )
        else:
            docs = await self._aget_docs(question, auth_context, semantic_context)  # type: ignore[call-arg]
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=_run_manager.get_child()
        )

        if self.return_source_documents:
            return {self.output_key: answer, "source_documents": docs}
        else:
            return {self.output_key: answer}

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Input keys: the query key plus the auth and semantic context keys.

        :meta private:
        """
        return [self.input_key, self.auth_context_key, self.semantic_context_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys; includes ``'source_documents'`` when those are returned.

        :meta private:
        """
        _output_keys = [self.output_key]
        if self.return_source_documents:
            _output_keys += ["source_documents"]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        """Return the chain type identifier."""
        return "pebblo_retrieval_qa"

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PebbloRetrievalQA":
        """Build the chain from a combine-documents chain type.

        Args:
            llm: Language model passed to ``load_qa_chain``.
            chain_type: Chain type name passed to ``load_qa_chain``.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``.
            **kwargs: Remaining fields forwarded to the class constructor
                (e.g. ``retriever``).

        Returns:
            A new ``PebbloRetrievalQA`` instance.
        """
        # Imported lazily, as in the original implementation.
        from langchain.chains.question_answering import load_qa_chain

        _chain_type_kwargs = chain_type_kwargs or {}
        combine_documents_chain = load_qa_chain(
            llm, chain_type=chain_type, **_chain_type_kwargs
        )
        return cls(combine_documents_chain=combine_documents_chain, **kwargs)

    @validator("retriever", pre=True, always=True)
    def validate_vectorstore(
        cls, retriever: VectorStoreRetriever
    ) -> VectorStoreRetriever:
        """Validate that the retriever's vector store is a supported one.

        Raises:
            ValueError: If ``retriever.vectorstore`` is not an instance of any
                class in ``SUPPORTED_VECTORSTORES``.
        """
        if not any(
            isinstance(retriever.vectorstore, supported_class)
            for supported_class in SUPPORTED_VECTORSTORES
        ):
            raise ValueError(
                f"Vectorstore must be an instance of one of the supported "
                f"vectorstores: {SUPPORTED_VECTORSTORES}. "
                f"Got {type(retriever.vectorstore).__name__} instead."
            )
        return retriever

    def _get_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters to the retriever, then fetch documents."""
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        return self.retriever.get_relevant_documents(
            question, callbacks=run_manager.get_child()
        )

    async def _aget_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Apply the enforcement filters to the retriever, then fetch documents
        asynchronously."""
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        return await self.retriever.aget_relevant_documents(
            question, callbacks=run_manager.get_child()
        )