Source code for langchain_community.embeddings.gradient_ai
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from packaging.version import parse
__all__ = ["GradientEmbeddings"]
[docs]class GradientEmbeddings(BaseModel, Embeddings):
"""Gradient.ai嵌入模型。
GradientLLM是一个用于与gradient.ai上的嵌入模型进行交互的类。
要使用,请设置环境变量``GRADIENT_ACCESS_TOKEN``为您的API令牌,``GRADIENT_WORKSPACE_ID``为您的gradient工作区,或者作为此类的构造函数的关键字参数提供它们。
示例:
.. code-block:: python
from langchain_community.embeddings import GradientEmbeddings
GradientEmbeddings(
model="bge-large",
gradient_workspace_id="12345614fc0_workspace",
gradient_access_token="gradientai-access_token",
)
"""
model: str
"潜在的gradient.ai模型ID。"
gradient_workspace_id: Optional[str] = None
"梯度.ai工作空间的基础workspace_id。"
gradient_access_token: Optional[str] = None
"""gradient.ai API令牌,可以通过访问https://auth.gradient.ai/select-workspace生成,并在个人资料下拉菜单中选择“访问令牌”。"""
gradient_api_url: str = "https://api.gradient.ai/api"
"""要使用的端点URL。"""
query_prompt_for_retrieval: Optional[str] = None
"""查询预提示"""
client: Any = None #: :meta private:
"""梯度客户端。"""
# LLM call kwargs
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""验证环境中是否存在API密钥和Python包。"""
values["gradient_access_token"] = get_from_dict_or_env(
values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
)
values["gradient_workspace_id"] = get_from_dict_or_env(
values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
)
values["gradient_api_url"] = get_from_dict_or_env(
values, "gradient_api_url", "GRADIENT_API_URL"
)
try:
import gradientai
except ImportError:
raise ImportError(
'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
)
if parse(gradientai.__version__) < parse("1.4.0"):
raise ImportError(
'GradientEmbeddings requires `pip install -U "gradientai>=1.4.0"`.'
)
gradient = gradientai.Gradient(
access_token=values["gradient_access_token"],
workspace_id=values["gradient_workspace_id"],
host=values["gradient_api_url"],
)
values["client"] = gradient.get_embeddings_model(slug=values["model"])
return values
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""调用Gradient的嵌入端点。
参数:
texts:要嵌入的文本列表。
返回:
每个文本的嵌入列表。
"""
inputs = [{"input": text} for text in texts]
result = self.client.embed(inputs=inputs).embeddings
return [e.embedding for e in result]
[docs] async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""异步调用Gradient的嵌入端点。
参数:
texts:要嵌入的文本列表。
返回:
每个文本的嵌入列表。
"""
inputs = [{"input": text} for text in texts]
result = (await self.client.aembed(inputs=inputs)).embeddings
return [e.embedding for e in result]
[docs] def embed_query(self, text: str) -> List[float]:
"""调用Gradient的嵌入端点。
参数:
text:要嵌入的文本。
返回:
文本的嵌入。
"""
query = (
f"{self.query_prompt_for_retrieval} {text}"
if self.query_prompt_for_retrieval
else text
)
return self.embed_documents([query])[0]
[docs] async def aembed_query(self, text: str) -> List[float]:
"""异步调用Gradient的嵌入端点。
参数:
text: 要嵌入的文本。
返回:
文本的嵌入。
"""
query = (
f"{self.query_prompt_for_retrieval} {text}"
if self.query_prompt_for_retrieval
else text
)
embeddings = await self.aembed_documents([query])
return embeddings[0]
[docs]class TinyAsyncGradientEmbeddingClient: #: :meta private:
"""已弃用,TinyAsyncGradientEmbeddingClient已被移除。
这个类仅用于向旧版本的langchain_community提供向后兼容性。
未来可能会被完全移除。"""
[docs] def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")