Source code for langchain.retrievers.document_compressors.cross_encoder_rerank
from __future__ import annotations
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.pydantic_v1 import Extra
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
[docs]class CrossEncoderReranker(BaseDocumentCompressor):
"""使用CrossEncoder进行重新排序的文档压缩器。"""
model: BaseCrossEncoder
"""用于计算查询和文档之间相似性得分的CrossEncoder模型。"""
top_n: int = 3
"""要返回的文档数量。"""
class Config:
"""这个pydantic对象的配置。"""
extra = Extra.forbid
arbitrary_types_allowed = True
[docs] def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""使用CrossEncoder重新对文档进行排名。
参数:
documents: 需要压缩的文档序列。
query: 用于压缩文档的查询。
callbacks: 在压缩过程中运行的回调函数。
返回值:
压缩后的文档序列。
"""
scores = self.model.score([(query, doc.page_content) for doc in documents])
docs_with_scores = list(zip(documents, scores))
result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
return [doc for doc, _ in result[: self.top_n]]