# Source code for langchain_community.embeddings.deepinfra

from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

DEFAULT_MODEL_ID = "sentence-transformers/clip-ViT-B-32"
MAX_BATCH_SIZE = 1024


class DeepInfraEmbeddings(BaseModel, Embeddings):
    """Deep Infra's embedding inference service.

    To use, you should have the environment variable ``DEEPINFRA_API_TOKEN``
    set with your API token, or pass it as a named parameter to the
    constructor.

    There are multiple embedding models available; see
    https://deepinfra.com/models?type=embeddings for a list.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import DeepInfraEmbeddings

            deepinfra_emb = DeepInfraEmbeddings(
                model_id="sentence-transformers/clip-ViT-B-32",
                deepinfra_api_token="my-api-key",
            )
            r1 = deepinfra_emb.embed_documents(
                [
                    "Alpha is the first letter of Greek alphabet",
                    "Beta is the second letter of Greek alphabet",
                ]
            )
            r2 = deepinfra_emb.embed_query(
                "What is the second letter of Greek alphabet"
            )
    """

    model_id: str = DEFAULT_MODEL_ID
    """Embeddings model to use."""
    normalize: bool = False
    """Whether to normalize the computed embeddings."""
    embed_instruction: str = "passage: "
    """Instruction prepended to each document before embedding."""
    query_instruction: str = "query: "
    """Instruction prepended to the query before embedding."""
    model_kwargs: Optional[dict] = None
    """Other model keyword args passed through to the inference request."""
    deepinfra_api_token: Optional[str] = None
    """API token for Deep Infra. If not provided, the token is read from
    the environment variable ``DEEPINFRA_API_TOKEN``."""
    batch_size: int = MAX_BATCH_SIZE
    """Batch size for embedding requests."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API token exists in the values dict or the
        environment, and store it in ``values``."""
        deepinfra_api_token = get_from_dict_or_env(
            values, "deepinfra_api_token", "DEEPINFRA_API_TOKEN"
        )
        values["deepinfra_api_token"] = deepinfra_api_token
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_id": self.model_id}

    def _embed(self, input: List[str]) -> List[List[float]]:
        """POST one batch of inputs to the Deep Infra inference endpoint.

        Args:
            input: Texts to embed in a single request.

        Returns:
            One embedding (list of floats) per input text.

        Raises:
            ValueError: If the HTTP request fails, the API returns a
                non-200 status code, or the response body is not valid
                JSON containing an ``embeddings`` key.
        """
        _model_kwargs = self.model_kwargs or {}

        # HTTP headers for authorization
        headers = {
            "Authorization": f"bearer {self.deepinfra_api_token}",
            "Content-Type": "application/json",
        }
        # send request
        try:
            res = requests.post(
                f"https://api.deepinfra.com/v1/inference/{self.model_id}",
                headers=headers,
                json={"inputs": input, "normalize": self.normalize, **_model_kwargs},
            )
        except requests.exceptions.RequestException as e:
            # Chain the underlying transport error so callers can see the cause.
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        if res.status_code != 200:
            raise ValueError(
                "Error raised by inference API HTTP code: %s, %s"
                % (res.status_code, res.text)
            )
        try:
            t = res.json()
            embeddings = t["embeddings"]
        # KeyError covers a well-formed JSON error payload that has no
        # "embeddings" key; JSONDecodeError covers a non-JSON body.
        except (KeyError, requests.exceptions.JSONDecodeError) as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {res.text}"
            ) from e

        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a Deep Infra deployed embedding model.

        For larger batches, the input list of texts is chunked into smaller
        batches to avoid exceeding the maximum request size.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        instruction_pairs = [f"{self.embed_instruction}{text}" for text in texts]

        chunks = [
            instruction_pairs[i : i + self.batch_size]
            for i in range(0, len(instruction_pairs), self.batch_size)
        ]
        for chunk in chunks:
            embeddings += self._embed(chunk)

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a Deep Infra deployed embedding model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        instruction_pair = f"{self.query_instruction}{text}"
        embedding = self._embed([instruction_pair])[0]
        return embedding