Skip to content

Cloudflare workersai

CloudflareEmbedding #

Bases: BaseEmbedding

Cloudflare Workers AI类,用于生成文本嵌入。

该类允许使用Cloudflare Workers AI和BAAI通用嵌入模型生成文本嵌入。

Args: account_id(str):Cloudflare账户ID。 auth_token(str,可选):Cloudflare授权令牌。或者,设置环境变量CLOUDFLARE_AUTH_TOKEN。 model(str):嵌入服务的模型ID。Cloudflare提供不同的嵌入模型,请查看https://developers.cloudflare.com/workers-ai/models/#text-embeddings。默认为"@cf/baai/bge-base-en-v1.5"。 embed_batch_size(int):嵌入生成的批处理大小。Cloudflare当前限制最大为100。默认为llama_index的默认值。

注意: 确保您拥有有效的Cloudflare账户,并可以访问必要的AI服务和模型。账户ID和授权令牌是敏感信息;请适当保护它们。

Source code in llama_index/embeddings/cloudflare_workersai/base.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class CloudflareEmbedding(BaseEmbedding):
    """Cloudflare Workers AI类,用于生成文本嵌入。

该类允许使用Cloudflare Workers AI和BAAI通用嵌入模型生成文本嵌入。

Args:
account_id(str):Cloudflare账户ID。
auth_token(str,可选):Cloudflare授权令牌。或者,设置环境变量`CLOUDFLARE_AUTH_TOKEN`。
model(str):嵌入服务的模型ID。Cloudflare提供不同的嵌入模型,请查看https://developers.cloudflare.com/workers-ai/models/#text-embeddings。默认为"@cf/baai/bge-base-en-v1.5"。
embed_batch_size(int):嵌入生成的批处理大小。Cloudflare当前限制最大为100。默认为llama_index的默认值。

注意:
确保您拥有有效的Cloudflare账户,并可以访问必要的AI服务和模型。账户ID和授权令牌是敏感信息;请适当保护它们。"""

    account_id: str = Field(default=None, description="The Cloudflare Account ID.")
    auth_token: str = Field(default=None, description="The Cloudflare Auth Token.")
    model: str = Field(
        default="@cf/baai/bge-base-en-v1.5",
        description="The model to use when calling Cloudflare AI API",
    )

    _session: Any = PrivateAttr()

    def __init__(
        self,
        account_id: str,
        auth_token: Optional[str] = None,
        model: str = "@cf/baai/bge-base-en-v1.5",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            **kwargs,
        )
        self.account_id = account_id
        self.auth_token = get_from_param_or_env(
            "auth_token", auth_token, "CLOUDFLARE_AUTH_TOKEN", ""
        )
        self.model = model
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {self.auth_token}"})

    @classmethod
    def class_name(cls) -> str:
        return "CloudflareEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """获取查询嵌入。"""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """_get_query_embedding的异步版本。"""
        return await self._aget_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """获取文本嵌入。"""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """异步获取文本嵌入。"""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """获取文本嵌入。"""
        response = self._session.post(
            API_URL_TEMPLATE.format(self.account_id, self.model), json={"text": texts}
        ).json()

        if "result" not in response:
            print(response)
            raise RuntimeError("Failed to fetch embeddings")

        return response["result"]["data"]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """异步获取文本嵌入。"""
        import aiohttp

        async with aiohttp.ClientSession(trust_env=True) as session:
            headers = {
                "Authorization": f"Bearer {self.auth_token}",
                "Accept-Encoding": "identity",
            }
            async with session.post(
                API_URL_TEMPLATE.format(self.account_id, self.model),
                json={"text": texts},
                headers=headers,
            ) as response:
                resp = await response.json()
                if "result" not in resp:
                    raise RuntimeError("Failed to fetch embeddings asynchronously")

                return resp["result"]["data"]