Source code for langchain_community.embeddings.sambanova
from typing import Dict, Generator, List
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env
[docs]class SambaStudioEmbeddings(BaseModel, Embeddings):
"""SambaNova嵌入模型。
要使用,您应该设置环境变量``SAMBASTUDIO_EMBEDDINGS_BASE_URL``、``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``、``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``、``SAMBASTUDIO_EMBEDDINGS_API_KEY``,并将其设置为您个人的sambastudio变量,或将其作为命名参数传递给构造函数。
示例:
.. code-block:: python
from langchain_community.embeddings import SambaStudioEmbeddings
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
sambastudio_embeddings_project_id=project_id,
sambastudio_embeddings_endpoint_id=endpoint_id,
sambastudio_embeddings_api_key=api_key)
(或)
embeddings = SambaStudioEmbeddings()"""
API_BASE_PATH = "/api/predict/nlp/"
"""用于API使用的基本路径"""
sambastudio_embeddings_base_url: str = ""
"""用于基础URL"""
sambastudio_embeddings_project_id: str = ""
"""在sambastudio上用于模型的项目ID"""
sambastudio_embeddings_endpoint_id: str = ""
"""在SambaStudio上用于模型的端点ID"""
sambastudio_embeddings_api_key: str = ""
"""Sambastudio API密钥"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""验证环境中是否存在API密钥和Python包。"""
values["sambastudio_embeddings_base_url"] = get_from_dict_or_env(
values, "sambastudio_embeddings_base_url", "SAMBASTUDIO_EMBEDDINGS_BASE_URL"
)
values["sambastudio_embeddings_project_id"] = get_from_dict_or_env(
values,
"sambastudio_embeddings_project_id",
"SAMBASTUDIO_EMBEDDINGS_PROJECT_ID",
)
values["sambastudio_embeddings_endpoint_id"] = get_from_dict_or_env(
values,
"sambastudio_embeddings_endpoint_id",
"SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID",
)
values["sambastudio_embeddings_api_key"] = get_from_dict_or_env(
values, "sambastudio_embeddings_api_key", "SAMBASTUDIO_EMBEDDINGS_API_KEY"
)
return values
def _get_full_url(self, path: str) -> str:
"""返回给定路径的完整API URL。
:param str path: 子路径
:returns: 子路径的完整API URL
:rtype: str
"""
return f"{self.sambastudio_embeddings_base_url}{self.API_BASE_PATH}{path}"
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
"""用于在嵌入文档方法中创建批次的生成器
参数:
texts(List[str]):要嵌入的字符串列表
batch_size(int,可选):用于嵌入模型的批次大小。
将取决于所使用的RDU端点。
产出:
List[str]:大小为批次大小的字符串列表(批次)
"""
for i in range(0, len(texts), batch_size):
yield texts[i : i + batch_size]
[docs] def embed_documents(
self, texts: List[str], batch_size: int = 32
) -> List[List[float]]:
"""返回给定句子的嵌入列表。
参数:
texts(`List[str]`):要编码的文本列表
batch_size(`int`):编码的批处理大小
返回:
`List[np.ndarray]` 或 `List[tensor]`:给定句子的嵌入列表
"""
http_session = requests.Session()
url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
)
embeddings = []
for batch in self._iterate_over_batches(texts, batch_size):
data = {"inputs": batch}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
embedding = response.json()["data"]
embeddings.extend(embedding)
return embeddings
[docs] def embed_query(self, text: str) -> List[float]:
"""返回给定句子的嵌入列表。
参数:
sentences(`List[str]`):要编码的句子列表
返回:
`List[np.ndarray]` 或 `List[tensor]`:给定句子的嵌入列表
"""
http_session = requests.Session()
url = self._get_full_url(
f"{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}"
)
data = {"inputs": [text]}
response = http_session.post(
url,
headers={"key": self.sambastudio_embeddings_api_key},
json=data,
)
embedding = response.json()["data"][0]
return embedding