24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124 | class GradientEmbedding(BaseEmbedding):
"""GradientAI嵌入模型。
该类提供了使用在Gradient AI中部署的模型生成嵌入的接口。在初始化时,需要提供集群中部署的模型的model_id。
注意:
需要在PYTHONPATH中可用`gradientai`包。可以使用`pip install gradientai`进行安装。"""
embed_batch_size: int = Field(default=GRADIENT_EMBED_BATCH_SIZE, gt=0)
_gradient: Any = PrivateAttr()
_model: Any = PrivateAttr()
@classmethod
def class_name(cls) -> str:
return "GradientEmbedding"
def __init__(
self,
*,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
gradient_model_slug: str,
gradient_access_token: Optional[str] = None,
gradient_workspace_id: Optional[str] = None,
gradient_host: Optional[str] = None,
**kwargs: Any,
):
"""初始化GradientEmbedding类。
在初始化过程中,导入了`gradientai`包。使用访问令牌、工作区ID和模型的slug,从Gradient AI获取模型并准备好使用。
Args:
embed_batch_size (int, optional): 用于生成嵌入的批处理大小。默认为10,必须大于0且小于等于100。
gradient_model_slug (str): Gradient AI帐户中模型的模型slug。
gradient_access_token (str, optional): Gradient AI帐户的访问令牌,如果为`None`,则从环境变量`GRADIENT_ACCESS_TOKEN`中读取。
gradient_workspace_id (str, optional): Gradient AI帐户的工作区ID,如果为`None`,则从环境变量`GRADIENT_WORKSPACE_ID`中读取。
gradient_host (str, optional): Gradient AI API的主机。默认为None,表示使用默认主机。
引发:
ImportError: 如果PYTHONPATH中找不到`gradientai`包。
ValueError: 如果无法从Gradient AI获取模型。
"""
if embed_batch_size <= 0:
raise ValueError(f"Embed batch size {embed_batch_size} must be > 0.")
self._gradient = gradientai.Gradient(
access_token=gradient_access_token,
workspace_id=gradient_workspace_id,
host=gradient_host,
)
try:
self._model = self._gradient.get_embeddings_model(slug=gradient_model_slug)
except gradientai.openapi.client.exceptions.UnauthorizedException as e:
logger.error(f"Error while loading model {gradient_model_slug}.")
self._gradient.close()
raise ValueError("Unable to fetch the requested embeddings model") from e
super().__init__(
embed_batch_size=embed_batch_size, model_name=gradient_model_slug, **kwargs
)
async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
"""
在异步方式下嵌入输入的文本序列。
"""
inputs = [{"input": text} for text in texts]
result = await self._model.aembed(inputs=inputs).embeddings
return [e.embedding for e in result]
def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
"""
嵌入输入的文本序列。
"""
inputs = [{"input": text} for text in texts]
result = self._model.embed(inputs=inputs).embeddings
return [e.embedding for e in result]
def _get_text_embedding(self, text: str) -> Embedding:
"""使用单个文本输入的_get_text_embeddings()的别名。"""
return self._get_text_embeddings([text])[0]
async def _aget_text_embedding(self, text: str) -> Embedding:
"""使用单个文本输入的_aget_text_embeddings()的别名。"""
embedding = await self._aget_text_embeddings([text])
return embedding[0]
async def _aget_query_embedding(self, query: str) -> Embedding:
embedding = await self._aget_text_embeddings(
[f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
)
return embedding[0]
def _get_query_embedding(self, query: str) -> Embedding:
return self._get_text_embeddings(
[f"{QUERY_INSTRUCTION_FOR_RETRIEVAL} {query}"]
)[0]
|