Bases: BaseEmbedding
Cloudflare Workers AI类,用于生成文本嵌入。
该类允许使用Cloudflare Workers AI和BAAI通用嵌入模型生成文本嵌入。
Args:
account_id(str):Cloudflare账户ID。
auth_token(str,可选):Cloudflare授权令牌。或者,设置环境变量CLOUDFLARE_AUTH_TOKEN
。
model(str):嵌入服务的模型ID。Cloudflare提供不同的嵌入模型,请查看https://developers.cloudflare.com/workers-ai/models/#text-embeddings。默认为"@cf/baai/bge-base-en-v1.5"。
embed_batch_size(int):嵌入生成的批处理大小。Cloudflare当前限制最大为100。默认为llama_index的默认值。
注意:
确保您拥有有效的Cloudflare账户,并可以访问必要的AI服务和模型。账户ID和授权令牌是敏感信息;请适当保护它们。
Source code in llama_index/embeddings/cloudflare_workersai/base.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115 | class CloudflareEmbedding(BaseEmbedding):
"""Cloudflare Workers AI类,用于生成文本嵌入。
该类允许使用Cloudflare Workers AI和BAAI通用嵌入模型生成文本嵌入。
Args:
account_id(str):Cloudflare账户ID。
auth_token(str,可选):Cloudflare授权令牌。或者,设置环境变量`CLOUDFLARE_AUTH_TOKEN`。
model(str):嵌入服务的模型ID。Cloudflare提供不同的嵌入模型,请查看https://developers.cloudflare.com/workers-ai/models/#text-embeddings。默认为"@cf/baai/bge-base-en-v1.5"。
embed_batch_size(int):嵌入生成的批处理大小。Cloudflare当前限制最大为100。默认为llama_index的默认值。
注意:
确保您拥有有效的Cloudflare账户,并可以访问必要的AI服务和模型。账户ID和授权令牌是敏感信息;请适当保护它们。"""
account_id: str = Field(default=None, description="The Cloudflare Account ID.")
auth_token: str = Field(default=None, description="The Cloudflare Auth Token.")
model: str = Field(
default="@cf/baai/bge-base-en-v1.5",
description="The model to use when calling Cloudflare AI API",
)
_session: Any = PrivateAttr()
def __init__(
self,
account_id: str,
auth_token: Optional[str] = None,
model: str = "@cf/baai/bge-base-en-v1.5",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model=model,
**kwargs,
)
self.account_id = account_id
self.auth_token = get_from_param_or_env(
"auth_token", auth_token, "CLOUDFLARE_AUTH_TOKEN", ""
)
self.model = model
self._session = requests.Session()
self._session.headers.update({"Authorization": f"Bearer {self.auth_token}"})
@classmethod
def class_name(cls) -> str:
return "CloudflareEmbedding"
def _get_query_embedding(self, query: str) -> List[float]:
"""获取查询嵌入。"""
return self._get_text_embedding(query)
async def _aget_query_embedding(self, query: str) -> List[float]:
"""_get_query_embedding的异步版本。"""
return await self._aget_text_embedding(query)
def _get_text_embedding(self, text: str) -> List[float]:
"""获取文本嵌入。"""
return self._get_text_embeddings([text])[0]
async def _aget_text_embedding(self, text: str) -> List[float]:
"""异步获取文本嵌入。"""
result = await self._aget_text_embeddings([text])
return result[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""获取文本嵌入。"""
response = self._session.post(
API_URL_TEMPLATE.format(self.account_id, self.model), json={"text": texts}
).json()
if "result" not in response:
print(response)
raise RuntimeError("Failed to fetch embeddings")
return response["result"]["data"]
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""异步获取文本嵌入。"""
import aiohttp
async with aiohttp.ClientSession(trust_env=True) as session:
headers = {
"Authorization": f"Bearer {self.auth_token}",
"Accept-Encoding": "identity",
}
async with session.post(
API_URL_TEMPLATE.format(self.account_id, self.model),
json={"text": texts},
headers=headers,
) as response:
resp = await response.json()
if "result" not in resp:
raise RuntimeError("Failed to fetch embeddings asynchronously")
return resp["result"]["data"]
|