Skip to content

Deepinfra

DeepInfraEmbeddingModel #

Bases: BaseEmbedding

一个用于访问通过DeepInfra API可用的嵌入模型的包装类。该类允许将DeepInfra嵌入轻松集成到您的项目中,支持文本嵌入的同步和异步检索。

示例: >>> from llama_index.embeddings.deepinfra import DeepInfraEmbeddingModel >>> model = DeepInfraEmbeddingModel() >>> print(model.get_text_embedding("Hello, world!")) [0.1, 0.2, 0.3, ...]

Source code in llama_index/embeddings/deepinfra/base.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class DeepInfraEmbeddingModel(BaseEmbedding):
    """一个用于访问通过DeepInfra API可用的嵌入模型的包装类。该类允许将DeepInfra嵌入轻松集成到您的项目中,支持文本嵌入的同步和异步检索。

Args:
    model_id(str):用于嵌入的模型的标识符。默认为'sentence-transformers/clip-ViT-B-32'。
    normalize(bool):标志,用于在检索后对嵌入进行规范化。默认为False。
    api_token(str):DeepInfra API令牌。如果未提供,则从环境变量'DEEPINFRA_API_TOKEN'中获取令牌。

示例:
    >>> from llama_index.embeddings.deepinfra import DeepInfraEmbeddingModel
    >>> model = DeepInfraEmbeddingModel()
    >>> print(model.get_text_embedding("Hello, world!"))
    [0.1, 0.2, 0.3, ...]"""

    """model_id可以从DeepInfra网站上获取。"""
    _model_id: str = PrivateAttr()
    """标准化标志,用于在检索后标准化嵌入。"""
    _normalize: bool = PrivateAttr()
    """api_token应该从DeepInfra网站上获取。"""
    _api_token: str = PrivateAttr()
    """query_prefix 用于向查询添加前缀。"""
    _query_prefix: str = PrivateAttr()
    """text_prefix 用于向文本添加前缀。"""
    _text_prefix: str = PrivateAttr()

    def __init__(
        self,
        model_id: str = DEFAULT_MODEL_ID,
        normalize: bool = False,
        api_token: str = None,
        callback_manager: Optional[CallbackManager] = None,
        query_prefix: str = "",
        text_prefix: str = "",
        embed_batch_size: int = MAX_BATCH_SIZE,
    ) -> None:
        """
        初始化参数。
        """
        super().__init__(
            callback_manager=callback_manager, embed_batch_size=embed_batch_size
        )

        self._model_id = model_id
        self._normalize = normalize
        self._api_token = api_token or os.getenv(ENV_VARIABLE, None)
        self._query_prefix = query_prefix
        self._text_prefix = text_prefix

    def _post(self, data: List[str]) -> List[List[float]]:
        """向DeepInfra推理API发送POST请求,并返回API响应。
输入数据被分成批次以避免超过最大批处理大小(1024)。

Args:
    data (List[str]): 要嵌入的字符串列表。

Returns:

"""
        url = self.get_url()
        chunked_data = _chunk(data, self.embed_batch_size)
        embeddings = []
        for chunk in chunked_data:
            response = requests.post(
                url,
                json={
                    "inputs": chunk,
                },
                headers={
                    "Authorization": f"Bearer {self._api_token}",
                    "Content-Type": "application/json",
                },
            )
            response.raise_for_status()
            embeddings.extend(response.json()["embeddings"])
        return embeddings

    def get_url(self):
        """
        获取DeepInfra API URL。
        """
        return f"{INFERENCE_URL}/{self._model_id}"

    async def _apost(self, data: List[str]) -> List[List[float]]:
        """向DeepInfra推理API发送POST请求,并返回API响应。
输入数据被分成批次以避免超过最大批处理大小(1024)。

Args:
    data(List[str]):要嵌入的字符串列表。
输出:
    List[float]:API返回的嵌入列表。
"""
        url = self.get_url()
        chunked_data = _chunk(data, self.embed_batch_size)
        embeddings = []
        for chunk in chunked_data:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    json={
                        "inputs": chunk,
                    },
                    headers={
                        "Authorization": f"Bearer {self._api_token}",
                        "Content-Type": "application/json",
                    },
                ) as resp:
                    response = await resp.json()
                    embeddings.extend(response["embeddings"])
        return embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """
        获取查询嵌入。
        """
        return self._post(self._add_query_prefix([query]))[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """
        异步获取查询嵌入。
        """
        response = await self._apost(self._add_query_prefix([query]))
        return response[0]

    def _get_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """
        获取查询嵌入。
        """
        return self._post(self._add_query_prefix(queries))

    async def _aget_query_embeddings(self, queries: List[str]) -> List[List[float]]:
        """
        异步获取查询嵌入。
        """
        return await self._apost(self._add_query_prefix(queries))

    def _get_text_embedding(self, text: str) -> List[float]:
        """
        获取文本嵌入。
        """
        return self._post(self._add_text_prefix([text]))[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """
        异步获取文本嵌入。
        """
        response = await self._apost(self._add_text_prefix([text]))
        return response[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        获取文本嵌入。
        """
        return self._post(self._add_text_prefix(texts))

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        异步获取文本嵌入。
        """
        return await self._apost(self._add_text_prefix(texts))

    def _add_query_prefix(self, queries: List[str]) -> List[str]:
        """
        为查询添加查询前缀。
        """
        return (
            [self._query_prefix + query for query in queries]
            if self._query_prefix
            else queries
        )

    def _add_text_prefix(self, texts: List[str]) -> List[str]:
        """
        给文本添加文本前缀。
        """
        return (
            [self._text_prefix + text for text in texts] if self._text_prefix else texts
        )

get_url #

get_url()

获取DeepInfra API URL。

Source code in llama_index/embeddings/deepinfra/base.py
101
102
103
104
105
def get_url(self):
    """
    获取DeepInfra API URL。
    """
    return f"{INFERENCE_URL}/{self._model_id}"