Source code for langchain.retrievers.document_compressors.cohere_rerank

from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor


@deprecated(
    since="0.0.30", removal="0.3.0", alternative_import="langchain_cohere.CohereRerank"
)
class CohereRerank(BaseDocumentCompressor):
    """Document compressor that uses the `Cohere Rerank API`.

    Marked deprecated since 0.0.30 and scheduled for removal in 0.3.0; the
    suggested replacement import is ``langchain_cohere.CohereRerank``.
    """

    client: Any = None
    """Cohere client to use for compressing documents."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model: str = "rerank-english-v2.0"
    """Model to use for reranking."""
    cohere_api_key: Optional[str] = None
    """Cohere API key. Must be specified directly or via the environment
    variable COHERE_API_KEY."""
    user_agent: str = "langchain"
    """Identifier for the application making the request."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the ``cohere`` package are available.

        If ``values`` does not already carry a truthy ``client``, the
        ``cohere`` package is imported and a ``cohere.Client`` is created from
        the API key (looked up in ``values`` under ``cohere_api_key`` or in the
        ``COHERE_API_KEY`` environment variable) and stored in
        ``values["client"]``.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            The same ``values`` dict, with ``client`` populated when it was
            missing.

        Raises:
            ImportError: If the ``cohere`` package cannot be imported. The
                original import failure is attached as the explicit cause.
        """
        # A caller-supplied client takes precedence; nothing to build.
        if values.get("client"):
            return values

        # Imported lazily so that ``cohere`` is only required when this class
        # actually has to construct a client.
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "Could not import cohere python package. "
                "Please install it with `pip install cohere`."
            ) from err

        cohere_api_key = get_from_dict_or_env(
            values, "cohere_api_key", "COHERE_API_KEY"
        )
        # This is a pre-validator, so field defaults have not been applied yet;
        # fall back to the same default as the ``user_agent`` field.
        client_name = values.get("user_agent", "langchain")
        values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
        return values

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        model: Optional[str] = None,
        top_n: Optional[int] = -1,
        max_chunks_per_doc: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return the rerank results for the documents against the query.

        Args:
            documents: A sequence of documents to rerank. ``Document``
                instances are sent as their ``page_content``; other items are
                passed to the client unchanged.
            query: The query to use for reranking.
            model: The model to use for reranking. Defaults to ``self.model``.
            top_n: The number of results to return. ``None`` is forwarded to
                the client as-is (documented upstream as "return all
                results"). Any non-positive value, including the default
                ``-1``, falls back to ``self.top_n``.
            max_chunks_per_doc: The maximum number of chunks to derive from a
                document.

        Returns:
            One dict per result, in the order the client returned them, with
            the keys ``"index"`` (position in ``documents``) and
            ``"relevance_score"``.
        """
        if len(documents) == 0:  # avoid an empty API call
            return []

        docs = [
            doc.page_content if isinstance(doc, Document) else doc
            for doc in documents
        ]
        model = model or self.model
        # -1 (or any value <= 0) is the "not specified" sentinel; None is a
        # meaningful value and must be forwarded untouched.
        top_n = top_n if (top_n is None or top_n > 0) else self.top_n

        results = self.client.rerank(
            query=query,
            documents=docs,
            model=model,
            top_n=top_n,
            max_chunks_per_doc=max_chunks_per_doc,
        )
        # The response may either be iterable itself or expose the list on a
        # ``results`` attribute; accept both shapes.
        results = getattr(results, "results", results)

        return [
            {"index": res.index, "relevance_score": res.relevance_score}
            for res in results
        ]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Compress documents using Cohere's rerank API.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process. Not
                used by this implementation.

        Returns:
            A list of new ``Document`` objects, one per rerank result, each
            with the original page content, a deep copy of the original
            metadata, and an added ``"relevance_score"`` metadata entry.
        """
        compressed = []
        for res in self.rerank(documents, query):
            doc = documents[res["index"]]
            # Deep-copy the metadata so adding the score never mutates the
            # caller's documents.
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed