Source code for langchain_community.embeddings.gradient_ai

from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from packaging.version import parse

__all__ = ["GradientEmbeddings"]


[docs]class GradientEmbeddings(BaseModel, Embeddings): """Gradient.ai嵌入模型。 GradientLLM是一个用于与gradient.ai上的嵌入模型进行交互的类。 要使用,请设置环境变量``GRADIENT_ACCESS_TOKEN``为您的API令牌,``GRADIENT_WORKSPACE_ID``为您的gradient工作区,或者作为此类的构造函数的关键字参数提供它们。 示例: .. code-block:: python from langchain_community.embeddings import GradientEmbeddings GradientEmbeddings( model="bge-large", gradient_workspace_id="12345614fc0_workspace", gradient_access_token="gradientai-access_token", ) """ model: str "潜在的gradient.ai模型ID。" gradient_workspace_id: Optional[str] = None "梯度.ai工作空间的基础workspace_id。" gradient_access_token: Optional[str] = None """gradient.ai API令牌,可以通过访问https://auth.gradient.ai/select-workspace生成,并在个人资料下拉菜单中选择“访问令牌”。""" gradient_api_url: str = "https://api.gradient.ai/api" """要使用的端点URL。""" query_prompt_for_retrieval: Optional[str] = None """查询预提示""" client: Any = None #: :meta private: """梯度客户端。""" # LLM call kwargs class Config: """此pydantic对象的配置。""" extra = Extra.forbid @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """验证环境中是否存在API密钥和Python包。""" values["gradient_access_token"] = get_from_dict_or_env( values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN" ) values["gradient_workspace_id"] = get_from_dict_or_env( values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID" ) values["gradient_api_url"] = get_from_dict_or_env( values, "gradient_api_url", "GRADIENT_API_URL" ) try: import gradientai except ImportError: raise ImportError( 'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.' ) if parse(gradientai.__version__) < parse("1.4.0"): raise ImportError( 'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.' ) gradient = gradientai.Gradient( access_token=values["gradient_access_token"], workspace_id=values["gradient_workspace_id"], host=values["gradient_api_url"], ) values["client"] = gradient.get_embeddings_model(slug=values["model"]) return values
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """调用Gradient的嵌入端点。 参数: texts:要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ inputs = [{"input": text} for text in texts] result = self.client.embed(inputs=inputs).embeddings return [e.embedding for e in result]
[docs] async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """异步调用Gradient的嵌入端点。 参数: texts:要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ inputs = [{"input": text} for text in texts] result = (await self.client.aembed(inputs=inputs)).embeddings return [e.embedding for e in result]
[docs] def embed_query(self, text: str) -> List[float]: """调用Gradient的嵌入端点。 参数: text:要嵌入的文本。 返回: 文本的嵌入。 """ query = ( f"{self.query_prompt_for_retrieval} {text}" if self.query_prompt_for_retrieval else text ) return self.embed_documents([query])[0]
[docs] async def aembed_query(self, text: str) -> List[float]: """异步调用Gradient的嵌入端点。 参数: text: 要嵌入的文本。 返回: 文本的嵌入。 """ query = ( f"{self.query_prompt_for_retrieval} {text}" if self.query_prompt_for_retrieval else text ) embeddings = await self.aembed_documents([query]) return embeddings[0]
[docs]class TinyAsyncGradientEmbeddingClient: #: :meta private: """已弃用,TinyAsyncGradientEmbeddingClient已被移除。 这个类仅用于向旧版本的langchain_community提供向后兼容性。 未来可能会被完全移除。"""
[docs] def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")