"""
集成检索器,通过使用加权的倒数积分融合多个检索器的结果。
"""
import asyncio
from collections import defaultdict
from collections.abc import Hashable
from itertools import chain
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
TypeVar,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever, RetrieverLike
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import ensure_config, patch_config
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
get_unique_config_specs,
)
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)
[docs]def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
"""基于键函数,生成可迭代对象的唯一元素。
参数:
iterable:要过滤的可迭代对象。
key:为每个元素返回可哈希键的函数。
生成:
基于键函数的可迭代对象的唯一元素。
"""
seen = set()
for e in iterable:
if (k := key(e)) not in seen:
seen.add(k)
yield e
class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the results of multiple retrievers.

    The per-retriever result lists are merged with weighted Reciprocal
    Rank Fusion (RRF).

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers.
            Defaults to equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between
            the importance of high-ranked items and the consideration
            given to lower-ranked items. Default is 60.
    """

    retrievers: List[RetrieverLike]
    weights: List[float]
    c: int = 60

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """List the configurable fields of this runnable."""
        return get_unique_config_specs(
            spec for retriever in self.retrievers for spec in retriever.config_specs
        )

    @root_validator(pre=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Default to equal weights when none (or an empty list) are given."""
        if not values.get("weights"):
            n_retrievers = len(values["retrievers"])
            values["weights"] = [1 / n_retrievers] * n_retrievers
        return values

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        """Invoke the ensemble retriever synchronously.

        Args:
            input: The query to search for.
            config: Optional runnable configuration.
            **kwargs: Extra arguments forwarded to the callback manager.

        Returns:
            The fused, re-ranked list of documents.
        """
        from langchain_core.callbacks import CallbackManager

        config = ensure_config(config)
        callback_manager = CallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            **kwargs,
        )
        try:
            result = self.rank_fusion(input, run_manager=run_manager, config=config)
        except Exception as e:
            run_manager.on_retriever_error(e)
            # Bare re-raise preserves the original traceback.
            raise
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
        return result

    async def ainvoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        """Invoke the ensemble retriever asynchronously.

        Args:
            input: The query to search for.
            config: Optional runnable configuration.
            **kwargs: Extra arguments forwarded to the callback manager.

        Returns:
            The fused, re-ranked list of documents.
        """
        from langchain_core.callbacks import AsyncCallbackManager

        config = ensure_config(config)
        callback_manager = AsyncCallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            input,
            name=config.get("run_name"),
            **kwargs,
        )
        try:
            result = await self.arank_fusion(
                input, run_manager=run_manager, config=config
            )
        except Exception as e:
            await run_manager.on_retriever_error(e)
            # Bare re-raise preserves the original traceback.
            raise
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
        return result

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the documents relevant to the given query.

        Args:
            query: The query to search for.

        Returns:
            The re-ranked list of documents.
        """
        # Delegate to the fused result of all underlying retrievers.
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the documents relevant to the given query.

        Args:
            query: The query to search for.

        Returns:
            The re-ranked list of documents.
        """
        # Delegate to the fused result of all underlying retrievers.
        return await self.arank_fusion(query, run_manager)

    @staticmethod
    def _coerce_documents(docs: List[Any]) -> List[Document]:
        """Wrap any non-Document result (e.g. a plain string) as a Document.

        Shared by the sync and async fusion paths so both coerce results
        identically.
        """
        return [
            doc if isinstance(doc, Document) else Document(page_content=str(doc))
            for doc in docs
        ]

    def rank_fusion(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Run all retrievers and fuse their results with weighted RRF.

        Args:
            query: The query to search for.
            run_manager: Callback manager; each retriever runs under a
                tagged child of it.
            config: Optional runnable configuration.

        Returns:
            The re-ranked list of documents.
        """
        # Collect each retriever's results, tagging child callbacks so
        # runs are attributable to individual retrievers.
        retriever_docs = [
            retriever.invoke(
                query,
                patch_config(
                    config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
                ),
            )
            for i, retriever in enumerate(self.retrievers)
        ]
        # Normalize: some retrievers may return raw strings instead of Documents.
        retriever_docs = [self._coerce_documents(docs) for docs in retriever_docs]
        return self.weighted_reciprocal_rank(retriever_docs)

    async def arank_fusion(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[Document]:
        """Asynchronously run all retrievers and fuse their results with RRF.

        Args:
            query: The query to search for.
            run_manager: Callback manager; each retriever runs under a
                tagged child of it.
            config: Optional runnable configuration.

        Returns:
            The re-ranked list of documents.
        """
        # Query all retrievers concurrently.
        retriever_docs = await asyncio.gather(
            *[
                retriever.ainvoke(
                    query,
                    patch_config(
                        config, callbacks=run_manager.get_child(tag=f"retriever_{i+1}")
                    ),
                )
                for i, retriever in enumerate(self.retrievers)
            ]
        )
        # Normalize: some retrievers may return raw strings instead of Documents.
        retriever_docs = [self._coerce_documents(docs) for docs in retriever_docs]
        return self.weighted_reciprocal_rank(retriever_docs)

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        More details on RRF:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Args:
            doc_lists: A list of rank lists, where each rank list contains
                unique items.

        Returns:
            The final aggregated list of items, sorted by descending
            weighted RRF score.

        Raises:
            ValueError: If the number of rank lists does not match the
                number of weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )
        # Accumulate each document's RRF score keyed by page content so that
        # duplicates across retrievers are collapsed and scored cumulatively.
        rrf_score: Dict[str, float] = defaultdict(float)
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score[doc.page_content] += weight / (rank + self.c)
        # Deduplicate by content, then sort by fused score (descending).
        all_docs = chain.from_iterable(doc_lists)
        return sorted(
            unique_by_key(all_docs, lambda doc: doc.page_content),
            reverse=True,
            key=lambda doc: rrf_score[doc.page_content],
        )