# Source code for langchain_community.chains.pebblo_retrieval.base

"""
使用身份和语义强制执行的Pebblo检索链,用于针对向量数据库进行问答。
"""

import inspect
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Extra, Field, validator
from langchain_core.vectorstores import VectorStoreRetriever

from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
    SUPPORTED_VECTORSTORES,
    set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
    AuthContext,
    SemanticContext,
)


class PebbloRetrievalQA(Chain):
    """Retrieval chain with identity and semantic enforcement for
    question-answering against a vector database."""

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain used to combine the retrieved documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Whether to return the source documents."""
    retriever: VectorStoreRetriever = Field(exclude=True)
    """Vector-store retriever used for retrieval."""
    auth_context_key: str = "auth_context"  #: :meta private:
    """Authentication context for identity enforcement."""
    semantic_context_key: str = "semantic_context"  #: :meta private:
    """Semantic context for semantic enforcement."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @validator("retriever", pre=True, always=True)
    def validate_vectorstore(
        cls, retriever: VectorStoreRetriever
    ) -> VectorStoreRetriever:
        """Check that the retriever's vectorstore is a supported vectorstore.

        Args:
            retriever: Candidate retriever; its ``vectorstore`` attribute is
                checked against ``SUPPORTED_VECTORSTORES``.

        Returns:
            The retriever, unchanged, when its vectorstore is supported.

        Raises:
            ValueError: If the vectorstore is not an instance of any class in
                ``SUPPORTED_VECTORSTORES``.
        """
        vectorstore = retriever.vectorstore
        # Guard clause: accept as soon as one supported class matches.
        if any(isinstance(vectorstore, known) for known in SUPPORTED_VECTORSTORES):
            return retriever
        raise ValueError(
            f"Vectorstore must be an instance of one of the supported "
            f"vectorstores: {SUPPORTED_VECTORSTORES}. "
            f"Got {type(vectorstore).__name__} instead."
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PebbloRetrievalQA":
        """Load the chain from a chain type.

        Args:
            llm: Language model handed to ``load_qa_chain``.
            chain_type: Chain type name handed to ``load_qa_chain``.
            chain_type_kwargs: Extra keyword arguments for ``load_qa_chain``;
                ``None`` (or any falsy value) means no extra arguments.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance whose ``combine_documents_chain`` is the loaded chain.
        """
        # Imported here, as in the original, rather than at module level.
        from langchain.chains.question_answering import load_qa_chain

        qa_chain = load_qa_chain(
            llm, chain_type=chain_type, **(chain_type_kwargs or {})
        )
        return cls(combine_documents_chain=qa_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key, self.auth_context_key, self.semantic_context_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        if self.return_source_documents:
            return [self.output_key, "source_documents"]
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "pebblo_retrieval_qa"

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on the input query.

        If the chain's 'return_source_documents' is 'True', the retrieved
        documents are returned under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Both contexts are read with .get(), so a missing key yields None.
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)

        # Only pass run_manager when _get_docs declares that parameter.
        manager_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._get_docs).parameters:
            manager_kwargs["run_manager"] = manager
        docs = self._get_docs(
            question, auth_context, semantic_context, **manager_kwargs
        )

        answer = self.combine_documents_chain.run(
            input_documents=docs, question=question, callbacks=manager.get_child()
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run get_relevant_text and llm on the input query (async).

        If the chain's 'return_source_documents' is 'True', the retrieved
        documents are returned under the key 'source_documents'.

        Example:
        .. code-block:: python

            res = indexqa({'query': 'This is my query'})
            answer, docs = res['result'], res['source_documents']
        """
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]
        # Both contexts are read with .get(), so a missing key yields None.
        auth_context = inputs.get(self.auth_context_key)
        semantic_context = inputs.get(self.semantic_context_key)

        # Only pass run_manager when _aget_docs declares that parameter.
        manager_kwargs: Dict[str, Any] = {}
        if "run_manager" in inspect.signature(self._aget_docs).parameters:
            manager_kwargs["run_manager"] = manager
        docs = await self._aget_docs(
            question, auth_context, semantic_context, **manager_kwargs
        )

        answer = await self.combine_documents_chain.arun(
            input_documents=docs, question=question, callbacks=manager.get_child()
        )

        result: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    def _get_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: CallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs.

        Calls ``set_enforcement_filters`` with the retriever and both contexts,
        then queries the retriever with a child callback manager.
        """
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return self.retriever.get_relevant_documents(
            question, callbacks=child_callbacks
        )

    async def _aget_docs(
        self,
        question: str,
        auth_context: Optional[AuthContext],
        semantic_context: Optional[SemanticContext],
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> List[Document]:
        """Get docs (async).

        Calls ``set_enforcement_filters`` with the retriever and both contexts,
        then awaits the retriever query with a child callback manager.
        """
        set_enforcement_filters(self.retriever, auth_context, semantic_context)
        child_callbacks = run_manager.get_child()
        return await self.retriever.aget_relevant_documents(
            question, callbacks=child_callbacks
        )