Source code for langchain.retrievers.ensemble

"""
集成检索器,通过使用加权的倒数积分融合多个检索器的结果。
"""
import asyncio
from collections import defaultdict
from collections.abc import Hashable
from itertools import chain
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    TypeVar,
    cast,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import ensure_config, patch_config
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    get_unique_config_specs,
)

T = TypeVar("T")
H = TypeVar("H", bound=Hashable)


def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
    """Yield unique elements of an iterable based on a key function.

    Args:
        iterable: The iterable to filter.
        key: A function that returns a hashable key for each element.

    Yields:
        Unique elements of the iterable based on the key function.
    """
    seen = set()
    for e in iterable:
        if (k := key(e)) not in seen:
            seen.add(k)
            yield e
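
A quick illustrative sketch of how unique_by_key deduplicates while preserving order (not part of the module source; the sample records below are hypothetical):

    records = [{"id": 1, "v": "a"}, {"id": 2, "v": "b"}, {"id": 1, "v": "c"}]
    # Keeps the first occurrence for each key: the second id-1 record is dropped.
    deduped = list(unique_by_key(records, key=lambda r: r["id"]))
    # deduped == [{"id": 1, "v": "a"}, {"id": 2, "v": "b"}]
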
class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the results of multiple retrievers.

    It uses rank fusion.

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to
            equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the
            importance of high-ranked items and the consideration given to
            lower-ranked items. Default is 60.
    """

    retrievers: List[RetrieverLike]
    weights: List[float]
    c: int = 60

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """List the configurable fields for this runnable."""
        return get_unique_config_specs(
            spec for retriever in self.retrievers for spec in retriever.config_specs
        )

    @root_validator(pre=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if not values.get("weights"):
            n_retrievers = len(values["retrievers"])
            values["weights"] = [1 / n_retrievers] * n_retrievers
        return values
    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        from langchain_core.callbacks import CallbackManager

        config = ensure_config(config)
        callback_manager = CallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            **kwargs,
        )
        try:
            result = self.rank_fusion(input, run_manager=run_manager, config=config)
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise e
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
    async def ainvoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        from langchain_core.callbacks import AsyncCallbackManager

        config = ensure_config(config)
        callback_manager = AsyncCallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            **kwargs,
        )
        try:
            result = await self.arank_fusion(
                input, run_manager=run_manager, config=config
            )
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise e
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get fused result of the retrievers.
        fused_documents = self.rank_fusion(query, run_manager)

        return fused_documents

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get fused result of the retrievers.
        fused_documents = await self.arank_fusion(query, run_manager)

        return fused_documents
    def rank_fusion(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Retrieve the results of the retrievers and use rank_fusion_func to get
        the final result.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get the results of all retrievers.
        retriever_docs = [
            retriever.invoke(
                query,
                patch_config(
                    config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
                ),
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # Enforce that retrieved docs are Documents for each list in retriever_docs
        for i in range(len(retriever_docs)):
            retriever_docs[i] = [
                Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
                for doc in retriever_docs[i]
            ]

        # apply rank fusion
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)

        return fused_documents
    async def arank_fusion(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Asynchronously retrieve the results of the retrievers and use
        rank_fusion_func to get the final result.

        Args:
            query: The query to search for.

        Returns:
            A list of reranked documents.
        """

        # Get the results of all retrievers.
        retriever_docs = await asyncio.gather(
            *[
                retriever.ainvoke(
                    query,
                    patch_config(
                        config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
                    ),
                )
                for i, retriever in enumerate(self.retrievers)
            ]
        )

        # Enforce that retrieved docs are Documents for each list in retriever_docs
        for i in range(len(retriever_docs)):
            retriever_docs[i] = [
                Document(page_content=doc) if not isinstance(doc, Document) else doc  # type: ignore[arg-type]
                for doc in retriever_docs[i]
            ]

        # apply rank fusion
        fused_documents = self.weighted_reciprocal_rank(retriever_docs)

        return fused_documents
    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, where each rank list contains
                unique items.

        Returns:
            list: The final aggregated list of items sorted by their weighted
                RRF scores in descending order.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Associate each doc's content with its RRF score for later sorting by it
        # Duplicated contents across retrievers are collapsed & scored cumulatively
        rrf_score: Dict[str, float] = defaultdict(float)
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score[doc.page_content] += weight / (rank + self.c)

        # Docs are deduplicated by their contents then sorted by their scores
        all_docs = chain.from_iterable(doc_lists)
        sorted_docs = sorted(
            unique_by_key(all_docs, lambda doc: doc.page_content),
            reverse=True,
            key=lambda doc: rrf_score[doc.page_content],
        )
        return sorted_docs
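
A minimal usage sketch of EnsembleRetriever (not part of the module source). It assumes the optional langchain_community, langchain_openai, and rank_bm25 packages are installed and an OpenAI API key is configured; the sample texts and the specific sparse/dense retriever pairing below are illustrative assumptions only:

    from langchain_community.retrievers import BM25Retriever
    from langchain_community.vectorstores import FAISS
    from langchain_openai import OpenAIEmbeddings

    from langchain.retrievers import EnsembleRetriever

    texts = ["RRF fuses ranked lists", "Ensembles combine retrievers"]

    # Sparse (keyword) retriever.
    bm25_retriever = BM25Retriever.from_texts(texts)
    bm25_retriever.k = 2

    # Dense (embedding) retriever.
    faiss_retriever = FAISS.from_texts(texts, OpenAIEmbeddings()).as_retriever(
        search_kwargs={"k": 2}
    )

    # Each document's fused score is sum(weights[i] / (rank_i + c)) over the
    # retrievers that return it; with c=60 and weights [0.5, 0.5], a document
    # ranked first by one retriever contributes 0.5 / 61 to its score.
    ensemble_retriever = EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5]
    )
    docs = ensemble_retriever.invoke("how does rank fusion work?")
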