# Source code for langchain_community.embeddings.infinity

"""根据MIT许可证编写,Michael Feil 2023。"""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import aiohttp
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

# Public API of this module: only the embeddings class is exported.
__all__ = ["InfinityEmbeddings"]


class InfinityEmbeddings(BaseModel, Embeddings):
    """Embedding model client for a self-hosted `infinity` server.

    See https://github.com/michaelfeil/infinity
    This also works for text-embeddings-inference and other
    self-hosted openai-compatible servers.

    Infinity is a package to interact with embedding models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddings
            InfinityEmbeddings(
                model="BAAI/bge-small",
                infinity_api_url="http://localhost:7997",
            )
    """

    model: str
    "Underlying Infinity model id."

    infinity_api_url: str = "http://localhost:7997"
    """Endpoint URL to use."""

    client: Any = None  #: :meta private:
    """Infinity client."""

    # LLM call kwargs
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the endpoint URL and attach an HTTP client to the values.

        The URL is looked up through ``get_from_dict_or_env`` using the
        ``infinity_api_url`` key and the ``INFINITY_API_URL`` environment
        variable, then used as the host of a freshly built
        ``TinyAsyncOpenAIInfinityEmbeddingClient``.

        Args:
            values: Field values collected by pydantic.

        Returns:
            The same mapping with ``infinity_api_url`` and ``client`` set.
        """
        api_url = get_from_dict_or_env(
            values, "infinity_api_url", "INFINITY_API_URL"
        )
        values["infinity_api_url"] = api_url
        values["client"] = TinyAsyncOpenAIInfinityEmbeddingClient(host=api_url)
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self.client.embed(model=self.model, texts=texts)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return await self.client.aembed(model=self.model, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Call out to Infinity's embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint for a single text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return (await self.aembed_documents([text]))[0]
class TinyAsyncOpenAIInfinityEmbeddingClient:  #: :meta private:
    """Small sync/async HTTP helper for embedding texts with Infinity.

    This is not part of the stable LangChain API and is not meant to be
    used directly.

    Example:
        ```python
        mini_client = TinyAsyncOpenAIInfinityEmbeddingClient()
        embeds = mini_client.embed(
            model="BAAI/bge-small",
            texts=["doc1", "doc2"],
        )
        # or
        embeds = await mini_client.aembed(
            model="BAAI/bge-small",
            texts=["doc1", "doc2"],
        )
        ```
    """

    def __init__(
        self,
        host: str = "http://localhost:7797/v1",
        aiosession: Optional["aiohttp.ClientSession"] = None,
    ) -> None:
        """Store the connection settings.

        Args:
            host: Base URL of the server; ``/embeddings`` is appended to it
                for every request.
            aiosession: Optional caller-owned aiohttp session used by
                ``aembed``. It is never closed by this client.

        Raises:
            ValueError: If ``host`` is ``None`` or shorter than 3 characters.
        """
        # NOTE(review): the default port here (7797) differs from the 7997
        # default used by InfinityEmbeddings in this file - confirm intent.
        if host is None or len(host) < 3:
            raise ValueError(" param `host` must be set to a valid url")
        self.host = host
        self.aiosession = aiosession
        self._batch_size = 128

    @staticmethod
    def _permute(
        texts: List[str], sorter: Callable = len
    ) -> Tuple[List[str], Callable]:
        """Sort texts longest-first and return a function undoing the sort.

        Grouping texts of similar length into the same batch follows
        https://github.com/UKPLab/sentence-transformers/blob/
        c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/SentenceTransformer.py#L156

        Args:
            texts: Texts to reorder.
            sorter: Key function; texts are ordered by descending
                ``sorter(text)``. Defaults to ``len``.

        Returns:
            A tuple ``(sorted_texts, undo)`` where ``undo`` maps any list
            that is in sorted order back to the original order.

        Example:
            ```
            texts = ["one","three","four"]
            perm_texts, undo = self._permute(texts)
            texts == undo(perm_texts)
            ```
        """
        if len(texts) <= 1:
            # Nothing to reorder for an empty list or a single query.
            return texts, lambda t: t
        length_sorted_idx = np.argsort([-sorter(sen) for sen in texts])
        texts_sorted = [texts[idx] for idx in length_sorted_idx]
        # Inverse permutation, computed once instead of on every undo call.
        restore_idx = np.argsort(length_sorted_idx)
        return texts_sorted, lambda unsorted_embeddings: [
            unsorted_embeddings[idx] for idx in restore_idx
        ]

    def _batch(self, texts: List[str]) -> List[List[str]]:
        """Split texts into consecutive chunks of at most ``self._batch_size``.

        Args:
            texts: List of sentences.

        Returns:
            List of batches, each a list of sentences.
        """
        return [
            texts[start : start + self._batch_size]
            for start in range(0, len(texts), self._batch_size)
        ]

    @staticmethod
    def _unbatch(batch_of_texts: List[List[Any]]) -> List[Any]:
        """Flatten a list of batches back into one flat list."""
        return [item for sublist in batch_of_texts for item in sublist]

    def _kwargs_post_request(self, model: str, texts: List[str]) -> Dict[str, Any]:
        """Build the keyword arguments for the embeddings POST request.

        Args:
            model: Model id sent in the request body.
            texts: Texts sent as the ``input`` of the request body.

        Returns:
            Dict with ``url``, ``headers`` and ``json`` entries, usable with
            both ``requests.post`` and an aiohttp session's ``post``.
        """
        return {
            "url": f"{self.host}/embeddings",
            "headers": {"content-type": "application/json"},
            "json": {"input": texts, "model": model},
        }

    def _sync_request_embed(
        self, model: str, batch_texts: List[str]
    ) -> List[List[float]]:
        """POST one batch synchronously and return its embeddings.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        response = requests.post(
            **self._kwargs_post_request(model=model, texts=batch_texts)
        )
        if response.status_code != 200:
            raise Exception(
                f"Infinity returned an unexpected response with status "
                f"{response.status_code}: {response.text}"
            )
        return [e["embedding"] for e in response.json()["data"]]

    def embed(self, model: str, texts: List[str]) -> List[List[float]]:
        """Embed texts synchronously.

        Args:
            model: Model used for embedding.
            texts: List of sentences to embed.

        Returns:
            One vector per sentence, in the order of ``texts``.
        """
        if not texts:
            # No request needed; avoids spinning up a thread pool for nothing.
            return []

        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        map_args = (
            self._sync_request_embed,
            [model] * len(perm_texts_batched),
            perm_texts_batched,
        )
        if len(perm_texts_batched) == 1:
            embeddings_batch_perm = list(map(*map_args))
        else:
            # Several batches: issue the HTTP requests concurrently.
            with ThreadPoolExecutor(32) as p:
                embeddings_batch_perm = list(p.map(*map_args))

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)

    async def _async_request(
        self, session: "aiohttp.ClientSession", kwargs: Dict[str, Any]
    ) -> List[List[float]]:
        """POST one batch through ``session`` and return its embeddings.

        Args:
            session: Open aiohttp session used for the request.
            kwargs: Keyword arguments for ``session.post`` as produced by
                ``_kwargs_post_request``.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        async with session.post(**kwargs) as response:
            if response.status != 200:
                # ``text`` is a coroutine method on aiohttp responses.
                body = await response.text()
                raise Exception(
                    f"Infinity returned an unexpected response with status "
                    f"{response.status}: {body}"
                )
            # Same payload shape as the sync path: a "data" list of items.
            payload = await response.json()
            return [e["embedding"] for e in payload["data"]]

    async def _async_embed_batches(
        self,
        session: "aiohttp.ClientSession",
        model: str,
        batches: List[List[str]],
    ) -> List[List[List[float]]]:
        """Send all batches concurrently over ``session``, keeping order."""
        return await asyncio.gather(
            *[
                self._async_request(
                    session=session,
                    kwargs=self._kwargs_post_request(model=model, texts=t),
                )
                for t in batches
            ]
        )

    async def aembed(self, model: str, texts: List[str]) -> List[List[float]]:
        """Embed texts asynchronously.

        Args:
            model: Model used for embedding.
            texts: List of sentences to embed.

        Returns:
            One vector per sentence, in the order of ``texts``.
        """
        if not texts:
            return []

        perm_texts, unpermute_func = self._permute(texts)
        perm_texts_batched = self._batch(perm_texts)

        if self.aiosession is not None:
            # Caller-owned session: reuse it and leave it open so that
            # subsequent calls keep working.
            embeddings_batch_perm = await self._async_embed_batches(
                self.aiosession, model, perm_texts_batched
            )
        else:
            # No session supplied: use a short-lived one for this call only.
            async with aiohttp.ClientSession(
                trust_env=True, connector=aiohttp.TCPConnector(limit=32)
            ) as session:
                embeddings_batch_perm = await self._async_embed_batches(
                    session, model, perm_texts_batched
                )

        embeddings_perm = self._unbatch(embeddings_batch_perm)
        return unpermute_func(embeddings_perm)