Source code for langchain_community.embeddings.sambanova

from typing import Dict, Generator, List

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env


[docs]class SambaStudioEmbeddings(BaseModel, Embeddings): """SambaNova嵌入模型。 要使用,您应该设置环境变量``SAMBASTUDIO_EMBEDDINGS_BASE_URL``、``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``、``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``、``SAMBASTUDIO_EMBEDDINGS_API_KEY``,并将其设置为您个人的sambastudio变量,或将其作为命名参数传递给构造函数。 示例: .. code-block:: python from langchain_community.embeddings import SambaStudioEmbeddings embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url, sambastudio_embeddings_project_id=project_id, sambastudio_embeddings_endpoint_id=endpoint_id, sambastudio_embeddings_api_key=api_key) (或) embeddings = SambaStudioEmbeddings()""" API_BASE_PATH = "/api/predict/nlp/" """用于API使用的基本路径""" sambastudio_embeddings_base_url: str = "" """用于基础URL""" sambastudio_embeddings_project_id: str = "" """在sambastudio上用于模型的项目ID""" sambastudio_embeddings_endpoint_id: str = "" """在SambaStudio上用于模型的端点ID""" sambastudio_embeddings_api_key: str = "" """Sambastudio API密钥""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: """验证环境中是否存在API密钥和Python包。""" values["sambastudio_embeddings_base_url"] = get_from_dict_or_env( values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL" ) values["sambastudio_embeddings_project_id"] = get_from_dict_or_env( values, "sambastudio_embeddings_project_id", "SAMBASTUDIO_EMBEDDINGS_PROJECT_ID", ) values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env( values, "sambastudio_embeddings_endpoint_id", "SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID", ) values["sambastudio_embeddings_api_key"] = get_from_dict_or_env( values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY" ) return values def _get_full_url(self, path: str) -> str: """返回给定路径的完整API URL。 :param str path: 子路径 :returns: 子路径的完整API URL :rtype: str """ return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}" def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator: """用于在嵌入文档方法中创建批次的生成器 参数: texts(List[str]):要嵌入的字符串列表 batch_size(int,可选):用于嵌入模型的批次大小。 将取决于所使用的RDU端点。 产出: List[str]:大小为批次大小的字符串列表(批次) """ for i in range(0, len(texts), batch_size): yield texts[i : i + batch_size]
[docs] def embed_documents( self, texts: List[str], batch_size: int = 32 ) -> List[List[float]]: """返回给定句子的嵌入列表。 参数: texts(`List[str]`):要编码的文本列表 batch_size(`int`):编码的批处理大小 返回: `List[np.ndarray]` 或 `List[tensor]`:给定句子的嵌入列表 """ http_session = requests.Session() url = self._get_full_url( f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" ) embeddings = [] for batch in self._iterate_over_batches(texts, batch_size): data = {"inputs": batch} response = http_session.post( url, headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) embedding = response.json()["data"] embeddings.extend(embedding) return embeddings
[docs] def embed_query(self, text: str) -> List[float]: """返回给定句子的嵌入列表。 参数: sentences(`List[str]`):要编码的句子列表 返回: `List[np.ndarray]` 或 `List[tensor]`:给定句子的嵌入列表 """ http_session = requests.Session() url = self._get_full_url( f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}" ) data = {"inputs": [text]} response = http_session.post( url, headers={"key": self.sambastudio_embeddings_api_key}, json=data, ) embedding = response.json()["data"][0] return embedding