# Source code for langchain_community.embeddings.embaas

from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import NotRequired, TypedDict

# Currently supported maximum batch size for embedding requests;
# embed_documents splits its input into chunks of at most this many texts.
MAX_BATCH_SIZE = 256
# Default endpoint used as the initial value of EmbaasEmbeddings.api_url.
EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"


[docs]class EmbaasEmbeddingsPayload(TypedDict): """用于Embaas嵌入式API的有效载荷。""" model: str texts: List[str] instruction: NotRequired[str]
class EmbaasEmbeddings(BaseModel, Embeddings):
    """Embaas's embedding service.

    To use, you should have the environment variable ``EMBAAS_API_KEY`` set
    with your API key, or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            # initialize with default model and instruction
            from langchain_community.embeddings import EmbaasEmbeddings
            emb = EmbaasEmbeddings()

            # initialize with custom model and instruction
            from langchain_community.embeddings import EmbaasEmbeddings
            emb_model = "instructor-large"
            emb_inst = "Represent the Wikipedia document for retrieval"
            emb = EmbaasEmbeddings(
                model=emb_model,
                instruction=emb_inst
            )
    """

    model: str = "e5-large-v2"
    """The model used for embeddings."""
    instruction: Optional[str] = None
    """Instruction used for domain-specific embeddings."""
    api_url: str = EMBAAS_API_URL
    """The URL for the embaas embeddings API."""
    embaas_api_key: Optional[SecretStr] = None
    """API key; read from the ``EMBAAS_API_KEY`` environment variable if unset."""
    max_retries: Optional[int] = 3
    """Max number of retries for requests."""
    timeout: Optional[int] = 30
    """Request timeout in seconds."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the given values or the environment.

        The key is looked up under ``embaas_api_key`` in ``values`` and,
        failing that, in the ``EMBAAS_API_KEY`` environment variable, then
        stored back as a secret string.
        """
        embaas_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "embaas_api_key", "EMBAAS_API_KEY")
        )
        values["embaas_api_key"] = embaas_api_key
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying params (model name and instruction)."""
        return {"model": self.model, "instruction": self.instruction}

    def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
        """Generate the payload for an API request.

        Args:
            texts: The texts to embed in this request.

        Returns:
            A payload with ``texts`` and ``model``; ``instruction`` is added
            only when a non-empty instruction is configured.
        """
        payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
        if self.instruction:
            payload["instruction"] = self.instruction
        return payload

    def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
        """Send a request to the Embaas API and parse the response.

        Args:
            payload: The request body to POST to ``api_url``.

        Returns:
            One embedding per item in the response's ``data`` list.

        Raises:
            requests.exceptions.RequestException: On connection failures,
                exhausted retries, or a 4xx/5xx response status.
        """
        headers = {
            "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "Content-Type": "application/json",
        }

        retries = Retry(
            total=self.max_retries,
            backoff_factor=0.5,
            allowed_methods=["POST"],
            raise_on_status=True,
        )

        # Use the session as a context manager so its connection pool is
        # released even when the request raises.
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retries))
            session.mount("https://", HTTPAdapter(max_retries=retries))
            response = session.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=self.timeout,
            )
            # Surface HTTP error statuses as HTTPError (with the response
            # attached) so the caller can report the API's error message
            # instead of failing later on a missing "data" key.
            response.raise_for_status()
            parsed_response = response.json()

        return [item["embedding"] for item in parsed_response["data"]]

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using the Embaas API.

        Args:
            texts: The texts to embed in a single request.

        Returns:
            List of embeddings, one per text.

        Raises:
            ValueError: If the request failed without a usable response
                body, or the API returned a JSON error with a ``message``.
            requests.exceptions.RequestException: If the API returned a JSON
                error body that carries no ``message`` field.
        """
        payload = self._generate_payload(texts)
        try:
            return self._handle_request(payload)
        except requests.exceptions.RequestException as e:
            # Response.__bool__ reflects the status code, so compare to None
            # explicitly rather than relying on truthiness.
            if e.response is None or not e.response.text:
                raise ValueError(
                    f"Error raised by embaas embeddings API: {e}"
                ) from e

            try:
                parsed_response = e.response.json()
            except ValueError:
                # Error body is not JSON (e.g. an HTML gateway page); report
                # the original failure instead of a JSON decoding error.
                raise ValueError(
                    f"Error raised by embaas embeddings API: {e}"
                ) from e

            if isinstance(parsed_response, dict) and "message" in parsed_response:
                raise ValueError(
                    "Validation Error raised by embaas embeddings API:"
                    f"{parsed_response['message']}"
                ) from e
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts.

        Args:
            texts: The list of texts to get embeddings for.

        Returns:
            List of embeddings, one for each text.
        """
        # Split the input into chunks of at most MAX_BATCH_SIZE texts; each
        # chunk is sent as its own request.
        batches = [
            texts[i : i + MAX_BATCH_SIZE] for i in range(0, len(texts), MAX_BATCH_SIZE)
        ]
        embeddings = [self._generate_embeddings(batch) for batch in batches]
        # flatten the list of lists into a single list
        return [embedding for batch in embeddings for embedding in batch]

    def embed_query(self, text: str) -> List[float]:
        """Get the embedding for a single text.

        Args:
            text: The text to get the embedding for.

        Returns:
            The embedding for ``text``.
        """
        return self.embed_documents([text])[0]