Source code for langchain.retrievers.contextual_compression

from typing import Any, List

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain.retrievers.document_compressors.base import (
    BaseDocumentCompressor,
)


class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and compresses its results.

    Candidate documents are fetched from ``base_retriever``. If any come
    back, they are passed to ``base_compressor`` together with the original
    query, and the compressor's output is returned as a list.
    """

    # Compressor applied to the documents produced by ``base_retriever``.
    base_compressor: BaseDocumentCompressor

    # Underlying retriever that supplies the candidate documents.
    base_retriever: BaseRetriever

    class Config:
        """Pydantic configuration for this model."""

        # Allow field types that pydantic does not natively know how to
        # validate (the retriever/compressor classes above).
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents for ``query`` and compress them.

        Args:
            query: String to find relevant documents for.
            run_manager: Callback manager for this retriever run. One child
                manager is passed to the base retriever, and, only when
                documents were found, another to the compressor.
            **kwargs: Extra keyword arguments forwarded to
                ``base_retriever.invoke``.

        Returns:
            The compressed documents as a list, or an empty list when the
            base retriever returned nothing.
        """
        retriever_config = {"callbacks": run_manager.get_child()}
        retrieved = self.base_retriever.invoke(query, config=retriever_config, **kwargs)

        # Nothing came back: skip the compressor entirely.
        if not retrieved:
            return []

        compressor_callbacks = run_manager.get_child()
        compressed = self.base_compressor.compress_documents(
            retrieved, query, callbacks=compressor_callbacks
        )
        return list(compressed)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously retrieve documents for ``query`` and compress them.

        Args:
            query: String to find relevant documents for.
            run_manager: Async callback manager for this retriever run. One
                child manager is passed to the base retriever, and, only
                when documents were found, another to the compressor.
            **kwargs: Extra keyword arguments forwarded to
                ``base_retriever.ainvoke``.

        Returns:
            The compressed documents as a list, or an empty list when the
            base retriever returned nothing.
        """
        retriever_config = {"callbacks": run_manager.get_child()}
        retrieved = await self.base_retriever.ainvoke(
            query, config=retriever_config, **kwargs
        )

        # Nothing came back: skip the compressor entirely.
        if not retrieved:
            return []

        compressor_callbacks = run_manager.get_child()
        compressed = await self.base_compressor.acompress_documents(
            retrieved, query, callbacks=compressor_callbacks
        )
        return list(compressed)