11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class LLMRailsEmbedding(BaseEmbedding):
    """LLMRails embedding model.

    Generates embeddings with a model deployed on LLMRails by POSTing to
    ``https://api.llmrails.com/v1/embeddings``. It requires the model's
    ``model_id`` and an API key, which can be obtained from
    https://console.llmrails.com/api-keys.
    """

    model_id: str
    api_key: str
    session: requests.Session

    @classmethod
    def class_name(cls) -> str:
        return "LLMRailsEmbedding"

    def __init__(
        self,
        api_key: str,
        model_id: str = "embedding-english-v1",  # or embedding-multi-v1
        **kwargs: Any,
    ):
        """Create the embedding client.

        Args:
            api_key (str): LLMRails API key, sent as the ``X-API-KEY`` header.
            model_id (str): Model to use, e.g. ``"embedding-english-v1"`` or
                ``"embedding-multi-v1"``.
            **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
        """
        # Retry connection/read failures and transient gateway errors
        # (502/503/504) with exponential backoff. POST is not in urllib3's
        # default allowed_methods, so it has to be enabled explicitly.
        retry = Retry(
            total=3,
            connect=3,
            read=2,
            allowed_methods=["POST"],
            backoff_factor=2,
            status_forcelist=[502, 503, 504],
        )
        session = requests.Session()
        session.mount("https://api.llmrails.com", HTTPAdapter(max_retries=retry))
        # Update (rather than replace) the headers so the session keeps the
        # default headers requests installs, e.g. User-Agent.
        session.headers.update({"X-API-KEY": api_key})
        super().__init__(model_id=model_id, api_key=api_key, session=session, **kwargs)

    def _get_embedding(self, text: str) -> List[float]:
        """Generate the embedding for a single text (synchronous).

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding of the input text.

        Raises:
            ValueError: If the API responds with an HTTP error status.
        """
        try:
            response = self.session.post(
                "https://api.llmrails.com/v1/embeddings",
                json={"input": [text], "model": self.model_id},
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            logger.error("Error while embedding text %s.", e)
            raise ValueError(f"Unable to embed given text {e}") from e
        return response.json()["data"][0]["embedding"]

    async def _aget_embedding(self, text: str) -> List[float]:
        """Generate the embedding for a single text (asynchronous).

        Unlike the synchronous path, this uses a fresh ``httpx.AsyncClient``
        per call and has no retry policy configured.

        Args:
            text (str): The text to embed.

        Returns:
            List[float]: The embedding of the input text.

        Raises:
            ImportError: If ``httpx`` is not installed.
            ValueError: If the request fails with any ``httpx.HTTPError``
                (an HTTP error status or a transport error).
        """
        try:
            import httpx
        except ImportError as err:
            raise ImportError(
                "The httpx library is required to use the async version of "
                "this function. Install it with `pip install httpx`."
            ) from err

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    "https://api.llmrails.com/v1/embeddings",
                    headers={"X-API-KEY": self.api_key},
                    json={"input": [text], "model": self.model_id},
                )
                response.raise_for_status()
        except httpx.HTTPError as e:
            logger.error("Error while embedding text %s.", e)
            raise ValueError(f"Unable to embed given text {e}") from e
        return response.json()["data"][0]["embedding"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Embed a document text; same endpoint as for queries."""
        return self._get_embedding(text)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Embed a query text; same endpoint as for documents."""
        return self._get_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a query text."""
        return await self._aget_embedding(query)

    async def _aget_text_embedding(self, query: str) -> List[float]:
        """Asynchronously embed a document text.

        The parameter is named ``query`` for backward compatibility; it is
        the document text to embed.
        """
        return await self._aget_embedding(query)
|