Source code for langchain_community.embeddings.infinity_local

"""根据MIT许可证编写,Michael Feil 2023。"""

import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

__all__ = ["InfinityEmbeddingsLocal"]

logger = getLogger(__name__)


class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity
    This class deploys a local Infinity instance to embed the texts.
    The class requires async usage.

    Infinity is a class for interacting with the embedding models on
    https://github.com/michaelfeil/infinity

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            embedder = InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            )
            async with embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Base model id from Hugging Face, e.g. BAAI/bge-small-en-v1.5"

    revision: Optional[str] = None
    "Model version, the commit hash from Hugging Face"

    batch_size: int = 32
    "Internal batch size for inference, e.g. 32"

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu', 'cuda' or 'mps'"

    backend: str = "torch"
    "Backend for inference, e.g. 'torch' (recommended for ROCm/Nvidia),"
    " or the best available choice from 'onnx'/'tensorrt'"

    model_warmup: bool = True
    "Warm up the model with the max batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine."""

    # LLM call kwargs
    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and the Python package exist in the environment."""
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            )
        logger.debug(f"Using InfinityEmbeddingsLocal with kwargs {values}")
        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> None:
        """Start the background worker.

        Recommended usage is via the async with statement:

            embedder = InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            )
            async with embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
        """
        await self.engine.__aenter__()

    async def __aexit__(self, *args: Any) -> None:
        """Stop the background worker.

        Required to free the reference to the PyTorch model.
        """
        await self.engine.__aexit__(*args)
    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        if not self.engine.running:
            logger.warning(
                "Starting Infinity engine on the fly. This is not recommended. "
                "Please start the engine before using it."
            )
            async with self:
                # spawning threadpool for multithreaded encode, tokenization
                embeddings, _ = await self.engine.embed(texts)
            # stopping threadpool on exit
            logger.warning("Stopped infinity engine after usage.")
        else:
            embeddings, _ = await self.engine.embed(texts)
        return embeddings
    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Infinity's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """This method is async only."""
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))
    def embed_query(self, text: str) -> List[float]:
        """This method is async only."""
        logger.warning(
            "This method is async only. "
            "Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))
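
A minimal usage sketch of the async API described above (not part of the module source; the model id, input texts, and the `main` helper are illustrative, and `infinity_emb[optimum,torch]>=0.0.24` is assumed to be installed):

    import asyncio

    from langchain_community.embeddings import InfinityEmbeddingsLocal


    async def main() -> None:
        embedder = InfinityEmbeddingsLocal(
            model="BAAI/bge-small-en-v1.5",  # any Hugging Face embedding model id
            device="cpu",
        )
        # The async context manager starts (and on exit stops) the background
        # worker that owns the PyTorch model.
        async with embedder:
            doc_vectors = await embedder.aembed_documents(["text1", "text2"])
            query_vector = await embedder.aembed_query("a query")
        print(len(doc_vectors), len(query_vector))


    asyncio.run(main())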
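
The sync wrappers `embed_documents`/`embed_query` delegate to the async API via `asyncio.run`, so they cannot be called from code that is already inside a running event loop; a sketch of the fallback path (same illustrative model id as above):

    # Without the context manager, aembed_documents starts the engine on the
    # fly and logs a warning, as the source above notes.
    embedder = InfinityEmbeddingsLocal(model="BAAI/bge-small-en-v1.5", device="cpu")
    vectors = embedder.embed_documents(["text1", "text2"])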