Source code for langchain_community.cross_encoders.huggingface

from typing import Any, Dict, List, Tuple

from langchain_core.pydantic_v1 import BaseModel, Extra, Field

from langchain_community.cross_encoders.base import BaseCrossEncoder

# Model name used by HuggingFaceCrossEncoder when the caller does not supply
# `model_name`.
DEFAULT_MODEL_NAME = "BAAI/bge-reranker-base"


class HuggingFaceCrossEncoder(BaseModel, BaseCrossEncoder):
    """HuggingFace cross encoder models.

    Wraps a ``sentence_transformers.CrossEncoder`` instance, which is created
    in ``__init__`` from ``model_name`` and ``model_kwargs``.

    Example:
        .. code-block:: python

            from langchain_community.cross_encoders import HuggingFaceCrossEncoder

            model_name = "BAAI/bge-reranker-base"
            model_kwargs = {'device': 'cpu'}
            hf = HuggingFaceCrossEncoder(
                model_name=model_name,
                model_kwargs=model_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_MODEL_NAME
    """Model name to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Args:
            **kwargs: Field values for this pydantic model (``model_name``,
                ``model_kwargs``). Any ``client`` passed here is replaced by
                the freshly constructed cross encoder.

        Raises:
            ImportError: If the ``sentence_transformers`` package is not
                installed.
        """
        super().__init__(**kwargs)
        # Imported lazily so the module can be imported without the optional
        # dependency; the failure only surfaces on instantiation.
        try:
            import sentence_transformers
        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence-transformers`."
            ) from exc

        self.client = sentence_transformers.CrossEncoder(
            self.model_name, **self.model_kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def score(self, text_pairs: List[Tuple[str, str]]) -> List[float]:
        """Compute similarity scores using a HuggingFace transformer model.

        Args:
            text_pairs: The list of text pairs to score the similarity of.

        Returns:
            A list of scores, one per text pair.
        """
        scores = self.client.predict(text_pairs)
        # sentence-transformers' predict normally hands back a numpy array
        # rather than a list; convert so the result honours the declared
        # List[float] contract (plain Python floats, JSON-serializable, safe
        # in boolean context).
        if hasattr(scores, "tolist"):
            return scores.tolist()
        return list(scores)