"""根据MIT许可证编写,Michael Feil 2023。"""
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple
import aiohttp
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
# Public API of this module; the low-level HTTP client class defined in this
# file is deliberately not exported (its docstring marks it as non-stable).
__all__ = ["InfinityEmbeddings"]
class InfinityEmbeddings(BaseModel, Embeddings):
    """Self-hosted embedding models served by the `infinity` package.

    See https://github.com/michaelfeil/infinity

    This also works with text-embeddings-inference and other self-hosted,
    OpenAI-compatible embedding servers.

    The endpoint URL is resolved in this order of precedence: the
    ``infinity_api_url`` argument, the ``INFINITY_API_URL`` environment
    variable, and finally the default ``http://localhost:7997``.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddings

            InfinityEmbeddings(
                model="BAAI/bge-small",
                infinity_api_url="http://localhost:7997",
            )
    """

    model: str
    """Underlying Infinity model id."""
    infinity_api_url: str = "http://localhost:7997"
    """Endpoint URL to use."""
    client: Any = None  #: :meta private:
    """Infinity HTTP client (built by the validator)."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True, allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the endpoint URL and build the HTTP client.

        This runs *before* field validation so that ``values`` holds only what
        the caller passed. As a post-validator the field default would always
        be present (and truthy), so ``INFINITY_API_URL`` would never be read.
        """
        values["infinity_api_url"] = get_from_dict_or_env(
            values,
            "infinity_api_url",
            "INFINITY_API_URL",
            # Keep in sync with the ``infinity_api_url`` field default above.
            default="http://localhost:7997",
        )
        values["client"] = TinyAsyncOpenAIInfinityEmbeddingClient(
            host=values["infinity_api_url"],
        )
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            One embedding per input text, in input order.
        """
        embeddings = self.client.embed(
            model=self.model,
            texts=texts,
        )
        return embeddings

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously call Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            One embedding per input text, in input order.
        """
        embeddings = await self.client.aembed(
            model=self.model,
            texts=texts,
        )
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Call Infinity's embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously call Infinity's embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
class TinyAsyncOpenAIInfinityEmbeddingClient:  #: :meta private:
    """Minimal sync/async client for an OpenAI-compatible ``/embeddings`` API.

    This is not part of LangChain's stable API and is not meant to be used
    directly.

    Example:
        .. code-block:: python

            mini_client = TinyAsyncOpenAIInfinityEmbeddingClient(
                host="http://localhost:7997",
            )
            embeds = mini_client.embed(
                model="BAAI/bge-small",
                texts=["doc1", "doc2"],
            )
            # or
            embeds = await mini_client.aembed(
                model="BAAI/bge-small",
                texts=["doc1", "doc2"],
            )
    """

    def __init__(
        self,
        host: str = "http://localhost:7797/v1",
        aiosession: Optional["aiohttp.ClientSession"] = None,
    ) -> None:
        """Create the client.

        Args:
            host: Base URL of the server; ``/embeddings`` is appended to it.
                NOTE(review): the default port 7797 differs from the 7997
                used elsewhere in this module and looks like a typo; it is
                kept unchanged for backward compatibility.
            aiosession: Optional caller-owned session used by :meth:`aembed`.
                It is never closed by this client.

        Raises:
            ValueError: If ``host`` is ``None`` or shorter than 3 characters.
        """
        self.host = host
        self.aiosession = aiosession
        if self.host is None or len(self.host) < 3:
            raise ValueError(" param `host` must be set to a valid url")
        # Maximum number of texts sent in a single HTTP request.
        self._batch_size = 128

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """Sort texts by *descending* ``sorter`` key and return an undo function.

        Grouping texts of similar length into the same batch reduces padding
        on the server. Adapted from
        https://github.com/UKPLab/sentence-transformers/blob/
        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156

        Args:
            texts: Texts to reorder.
            sorter: Key function used for ordering. Defaults to ``len``.

        Returns:
            ``(sorted_texts, undo)`` where ``undo`` maps a list aligned with
            ``sorted_texts`` back to the original order of ``texts``.

        Example:
            ```
            texts = ["one", "three", "four"]
            perm_texts, undo = self._permute(texts)
            texts == undo(perm_texts)
            ```
        """
        if len(texts) == 1:
            # Special case: single query, nothing to reorder.
            return texts, lambda t: t
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]
        # Inverse permutation, computed once rather than on every undo call.
        inverse_idx = np.argsort(length_sorted_idx)
        return texts_sorted, lambda unsorted_embeddings: [  # noqa E731
            unsorted_embeddings[idx] for idx in inverse_idx
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """Split ``texts`` into consecutive chunks of at most ``self._batch_size``.

        Args:
            texts: List of sentences.

        Returns:
            List of batches; concatenating them yields ``texts`` again.
        """
        if len(texts) == 1:
            # Special case: single query.
            return [texts]
        return [
            texts[start : start + self._batch_size]
            for start in range(0, len(texts), self._batch_size)
        ]

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        """Flatten a list of batches back into a single list (inverse of ``_batch``)."""
        if len(batch_of_texts) == 1 and len(batch_of_texts[0]) == 1:
            # Special case: single query.
            return batch_of_texts[0]
        return [item for batch in batch_of_texts for item in batch]

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """Build the keyword arguments for an embeddings POST request.

        The same dict is used for ``requests.post`` and for
        ``aiohttp.ClientSession.post``.

        Args:
            model: Model id sent in the request body.
            texts: Texts sent as the ``input`` field.

        Returns:
            Dict with ``url``, ``headers`` and ``json`` entries.
        """
        return dict(
            url=f"{self.host}/embeddings",
            headers={
                "content-type": "application/json",
            },
            json=dict(
                input=texts,
                model=model,
            ),
        )

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        """POST one batch synchronously and return its embeddings in order.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``model`` using blocking HTTP requests.

        Batches are sent concurrently from a thread pool when there is more
        than one.

        Args:
            model: The model to embed with.
            texts: The sentences to embed.

        Returns:
            One vector per sentence, in input order.
        """
        if not texts:
            # Nothing to send; avoid spinning up a thread pool.
            return []
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)
        models = [model] * len(perm_texts_batched)
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(
                map(self._sync_request_embed, models, perm_texts_batched)
            )
        else:
            with ThreadPoolExecutor(max_workers=32) as pool:
                embeddings_batch_perm = list(
                    pool.map(self._sync_request_embed, models, perm_texts_batched)
                )
        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)

    async def _async_request(
        self, session: "aiohttp.ClientSession", kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        """POST one batch over ``session`` and return its embeddings in order.

        Args:
            session: Open HTTP session to send the request on.
            kwargs: Request arguments from ``_kwargs_post_request``.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        async with session.post(**kwargs) as response:
            if response.status != 200:
                # aiohttp's ``text`` is a coroutine method and must be awaited.
                detail = await response.text()
                raise Exception(
                    f"Infinity returned an unexpected response with status "
                    f"{response.status}: {detail}"
                )
            payload = await response.json()
        # OpenAI-compatible schema, same key as the synchronous path.
        return [e["embedding"] for e in payload["data"]]

    async def _agather_batches(
        self,
        session: "aiohttp.ClientSession",
        model: str,
        batches: List[List[str]],
    ) -> List[List[List[float]]]:
        """Send all batches concurrently over ``session``; results keep batch order."""
        return list(
            await asyncio.gather(
                *(
                    self._async_request(
                        session=session,
                        kwargs=self._kwargs_post_request(model=model, texts=batch),
                    )
                    for batch in batches
                )
            )
        )

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``model`` asynchronously.

        If an ``aiosession`` was supplied it is reused and left open (the
        caller owns it). Otherwise a temporary session is opened for this
        call and closed afterwards, so repeated calls keep working.

        Args:
            model: The model to embed with.
            texts: The sentences to embed.

        Returns:
            One vector per sentence, in input order.
        """
        if not texts:
            return []
        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)
        if self.aiosession is not None:
            embeddings_batch_perm = await self._agather_batches(
                self.aiosession, model, perm_texts_batched
            )
        else:
            async with aiohttp.ClientSession(
                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
            ) as session:
                embeddings_batch_perm = await self._agather_batches(
                    session, model, perm_texts_batched
                )
        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)