# Source code for langchain_community.document_compressors.flashrank_rerank

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Sequence

from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra, root_validator

if TYPE_CHECKING:
    from flashrank import Ranker, RerankRequest
else:
    # Avoid pydantic annotation issues when actually instantiating
    # while keeping this import optional
    try:
        from flashrank import Ranker, RerankRequest
    except ImportError:
        pass

DEFAULT_MODEL_NAME = "ms-marco-MultiBERT-L-12"


class FlashrankRerank(BaseDocumentCompressor):
    """Document compressor that reranks documents using the Flashrank interface."""

    client: Ranker
    """Flashrank client to use for compressing documents."""
    top_n: int = 3
    """Number of documents to return."""
    model: Optional[str] = None
    """Model to use for reranking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the flashrank python package is installed.

        Fills in ``model`` (falling back to ``DEFAULT_MODEL_NAME``) and
        constructs the ``client`` Ranker from it.

        Raises:
            ImportError: If the ``flashrank`` package is not installed.
        """
        try:
            from flashrank import Ranker
        except ImportError as e:
            raise ImportError(
                "Could not import flashrank python package. "
                "Please install it with `pip install flashrank`."
            ) from e

        # `values.get("model", DEFAULT_MODEL_NAME)` would keep an explicitly
        # passed `model=None`; `or` also falls back when the key is present
        # but None, so Ranker never receives model_name=None.
        values["model"] = values.get("model") or DEFAULT_MODEL_NAME
        values["client"] = Ranker(model_name=values["model"])
        return values

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and keep the best ``top_n``.

        Args:
            documents: Documents to rerank.
            query: Query string the documents are scored against.
            callbacks: Accepted for interface compatibility; unused here.

        Returns:
            Up to ``top_n`` documents ordered by descending relevance, each
            with a ``relevance_score`` entry added to its metadata.
        """
        passages = [
            {"id": i, "text": doc.page_content, "meta": doc.metadata}
            for i, doc in enumerate(documents)
        ]

        rerank_request = RerankRequest(query=query, passages=passages)
        rerank_response = self.client.rerank(rerank_request)[: self.top_n]
        final_results = []
        for r in rerank_response:
            # Copy before annotating so the caller's original document
            # metadata dict is not mutated in place.
            metadata = dict(r["meta"])
            metadata["relevance_score"] = r["score"]
            doc = Document(
                page_content=r["text"],
                metadata=metadata,
            )
            final_results.append(doc)
        return final_results