""" **Retriever** 类根据文本 **query** 返回文档。
它比向量存储更通用。检索器不需要能够存储文档,只需要返回(或检索)文档。向量存储可以用作检索器的支柱,但也有其他类型的检索器。
**类层次结构:**
.. code-block::
BaseRetriever --> <name>Retriever # 例如:ArxivRetriever, MergerRetriever
**主要辅助功能:**
.. code-block::
RetrieverInput, RetrieverOutput, RetrieverLike, RetrieverOutputLike,
Document, Serializable, Callbacks,
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
"""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
RetrieverInput = str
RetrieverOutput = List[Document]
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
[docs]class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""抽象基类,用于文档检索系统。
检索系统被定义为可以接受字符串查询并从某个来源返回最相关的文档的东西。
用法:
检索器遵循标准的Runnable接口,应通过标准的runnable方法`invoke`,`ainvoke`,`batch`,`abatch`来使用。
实现:
当实现自定义检索器时,该类应实现`_get_relevant_documents`方法来定义检索文档的逻辑。
可选地,可以通过覆盖`_aget_relevant_documents`方法来提供异步本机实现。
示例:一个从文档列表中返回前5个文档的检索器
.. code-block:: python
from langchain_core import Document, BaseRetriever
from typing import List
class SimpleRetriever(BaseRetriever):
docs: List[Document]
k: int = 5
def _get_relevant_documents(self, query: str) -> List[Document]:
\"\"\"从文档列表中返回前k个文档\"\"\"
return self.docs[:self.k]
async def _aget_relevant_documents(self, query: str) -> List[Document]:
\"\"\"(可选)异步本机实现。\"\"\"
return self.docs[:self.k]
示例:基于scitkit learn向量化器的简单检索器
.. code-block:: python
from sklearn.metrics.pairwise import cosine_similarity
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def _get_relevant_documents(self, query: str) -> List[Document]:
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
query_vec = self.vectorizer.transform([query])
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
""" # noqa: E501
class Config:
"""此pydantic对象的配置。"""
arbitrary_types_allowed = True
_new_arg_supported: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
"""可选的标签列表,与检索器相关联。默认为None
这些标签将与对该检索器的每次调用相关联,并作为参数传递给在`callbacks`中定义的处理程序。
您可以使用这些标签来识别检索器的特定实例及其用例。"""
metadata: Optional[Dict[str, Any]] = None
"""与检索器关联的可选元数据。默认为None
此元数据将与每次调用此检索器关联,
并作为参数传递给在`callbacks`中定义的处理程序。
您可以使用这些元数据来识别检索器的特定实例及其用例。"""
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
[docs] def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
"""调用检索器以获取相关文档。
同步检索器调用的主要入口点。
参数:
input: 查询字符串
config: 检索器的配置
**kwargs: 传递给检索器的其他参数
返回:
相关文档的列表
示例:
.. code-block:: python
retriever.invoke("query")
"""
config = ensure_config(config)
return self.get_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
[docs] async def ainvoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> List[Document]:
"""异步调用检索器以获取相关文档。
异步检索器调用的主要入口点。
参数:
input: 查询字符串
config: 检索器的配置
**kwargs: 传递给检索器的额外参数
返回:
相关文档的列表
示例:
.. code-block:: python
await retriever.ainvoke("query")
"""
config = ensure_config(config)
return await self.aget_relevant_documents(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""获取与查询相关的文档。
参数:
query:要查找相关文档的字符串
run_manager:要使用的回调处理程序
返回:
相关文档的列表
"""
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""异步获取与查询相关的文档。
参数:
query: 要查找相关文档的字符串
run_manager: 要使用的回调处理程序
返回:
相关文档的列表
"""
return await run_in_executor(
None,
self._get_relevant_documents,
query,
run_manager=run_manager.get_sync(),
)
[docs] @deprecated(since="0.1.46", alternative="invoke", removal="0.3.0")
def get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""检索与查询相关的文档。
用户应该更倾向于使用`.invoke`或`.batch`,而不是直接使用`get_relevant_documents`。
参数:
query: 要查找相关文档的字符串
callbacks: 回调管理器或回调列表
tags: 可选的与检索器相关联的标签列表。默认为None
这些标签将与每次调用此检索器相关联,
并作为参数传递给在`callbacks`中定义的处理程序。
metadata: 与检索器相关联的可选元数据。默认为None
此元数据将与每次调用此检索器相关联,
并作为参数传递给在`callbacks`中定义的处理程序。
run_name: 运行的可选名称。
返回:
相关文档列表
"""
from langchain_core.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = self._get_relevant_documents(query, **_kwargs)
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
)
return result
[docs] @deprecated(since="0.1.46", alternative="ainvoke", removal="0.3.0")
async def aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""异步获取与查询相关的文档。
用户应该更倾向于使用`.ainvoke`或`.abatch`,而不是直接使用`aget_relevant_documents`。
参数:
query: 要查找相关文档的字符串
callbacks: 回调管理器或回调列表
tags: 与检索器相关的可选标签列表。默认为None
这些标签将与每次调用此检索器相关联,
并作为参数传递给在`callbacks`中定义的处理程序。
metadata: 与检索器相关的可选元数据。默认为None
这些元数据将与每次调用此检索器相关联,
并作为参数传递给在`callbacks`中定义的处理程序。
run_name: 运行的可选名称。
返回:
相关文档列表
"""
from langchain_core.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
name=run_name,
run_id=kwargs.pop("run_id", None),
)
try:
_kwargs = kwargs if self._expects_other_args else {}
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **_kwargs
)
else:
result = await self._aget_relevant_documents(query, **_kwargs)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
else:
await run_manager.on_retriever_end(
result,
)
return result