Source code for langchain_community.embeddings.embaas
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import NotRequired, TypedDict
# Largest number of texts the Embaas API currently supports in a single
# embedding request; embed_documents splits larger inputs into chunks of
# at most this many texts.
MAX_BATCH_SIZE: int = 256
# Default endpoint of the Embaas embeddings API.
EMBAAS_API_URL: str = "https://api.embaas.io/v1/embeddings/"
[docs]class EmbaasEmbeddingsPayload(TypedDict):
"""用于Embaas嵌入式API的有效载荷。"""
model: str
texts: List[str]
instruction: NotRequired[str]
class EmbaasEmbeddings(BaseModel, Embeddings):
    """Embaas's embedding service.

    To use, either set the environment variable ``EMBAAS_API_KEY`` to your API
    key, or pass it as the ``embaas_api_key`` named parameter to the
    constructor.

    Example:
        .. code-block:: python

            # Initialise with the default model and instruction
            from langchain_community.embeddings import EmbaasEmbeddings
            emb = EmbaasEmbeddings()

            # Initialise with a custom model and instruction
            from langchain_community.embeddings import EmbaasEmbeddings
            emb_model = "instructor-large"
            emb_inst = "Represent the Wikipedia document for retrieval"
            emb = EmbaasEmbeddings(
                model=emb_model,
                instruction=emb_inst
            )
    """

    model: str = "e5-large-v2"
    """The model used for embeddings."""
    instruction: Optional[str] = None
    """Instruction used for domain-specific embeddings."""
    api_url: str = EMBAAS_API_URL
    """The URL of the embaas embeddings API."""
    embaas_api_key: Optional[SecretStr] = None
    """API key for embaas; falls back to the ``EMBAAS_API_KEY`` env variable."""
    max_retries: Optional[int] = 3
    """Max number of retries for requests."""
    timeout: Optional[int] = 30
    """Request timeout in seconds."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the constructor args or the environment.

        The key is taken from ``values["embaas_api_key"]`` or the
        ``EMBAAS_API_KEY`` environment variable and stored as a ``SecretStr``.
        """
        embaas_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "embaas_api_key", "EMBAAS_API_KEY")
        )
        values["embaas_api_key"] = embaas_api_key
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (model and instruction)."""
        return {"model": self.model, "instruction": self.instruction}

    def _generate_payload(self, texts: List[str]) -> EmbaasEmbeddingsPayload:
        """Build the request body for one batch of ``texts``.

        ``instruction`` is only included when it is set to a non-empty value.
        """
        payload = EmbaasEmbeddingsPayload(texts=texts, model=self.model)
        if self.instruction:
            payload["instruction"] = self.instruction
        return payload

    def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
        """POST ``payload`` to the Embaas API and return the embeddings.

        Returns:
            One embedding per input text, taken from the ``embedding`` field of
            each item in the response's ``data`` list.

        Raises:
            requests.exceptions.RequestException: on connection failure, retry
                exhaustion, or a non-2xx HTTP status. For an ``HTTPError`` the
                response is attached so the caller can read the API's error body.
        """
        headers = {
            "Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",  # type: ignore[union-attr]
            "Content-Type": "application/json",
        }
        retries = Retry(
            total=self.max_retries,
            backoff_factor=0.5,
            # urllib3 does not retry POST unless it is listed explicitly.
            allowed_methods=["POST"],
            raise_on_status=True,
        )
        # Use the session as a context manager so its connection pool is
        # released after every call instead of leaking one pool per request.
        with requests.Session() as session:
            session.mount("http://", HTTPAdapter(max_retries=retries))
            session.mount("https://", HTTPAdapter(max_retries=retries))
            response = session.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=self.timeout,
            )
            # Without this, an error response (4xx/5xx) would be parsed as a
            # success and fail later with an opaque KeyError on "data".
            response.raise_for_status()
            parsed_response = response.json()
        return [item["embedding"] for item in parsed_response["data"]]

    def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a single batch of ``texts`` using the Embaas API.

        Raises:
            ValueError: if the request failed without a response body, or the
                error body is a JSON object containing a ``message`` field.
            requests.exceptions.RequestException: the original error, re-raised
                when its body carries no usable ``message``.
        """
        payload = self._generate_payload(texts)
        try:
            return self._handle_request(payload)
        except requests.exceptions.RequestException as e:
            if e.response is None or not e.response.text:
                raise ValueError(
                    f"Error raised by embaas embeddings API: {e}"
                ) from e
            try:
                parsed_response = e.response.json()
            except ValueError:
                # Non-JSON error body (e.g. an HTML gateway page): fall through
                # and re-raise the original request error below.
                parsed_response = None
            if isinstance(parsed_response, dict) and "message" in parsed_response:
                raise ValueError(
                    "Validation Error raised by embaas embeddings API: "
                    f"{parsed_response['message']}"
                ) from e
            # Bare raise re-raises the RequestException being handled.
            raise

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts.

        The texts are sent in chunks of at most ``MAX_BATCH_SIZE`` per request.
        An empty list returns an empty list without calling the API.

        Args:
            texts: The list of texts to get embeddings for.

        Returns:
            List of embeddings, one for each text, in input order.
        """
        batches = [
            texts[i : i + MAX_BATCH_SIZE] for i in range(0, len(texts), MAX_BATCH_SIZE)
        ]
        embeddings = [self._generate_embeddings(batch) for batch in batches]
        # flatten the per-batch lists into a single list
        return [embedding for batch in embeddings for embedding in batch]

    def embed_query(self, text: str) -> List[float]:
        """Get the embedding for a single text.

        Args:
            text: The text to get an embedding for.

        Returns:
            The embedding of ``text``.
        """
        return self.embed_documents([text])[0]