# Source code for langchain_community.embeddings.infinity_local
"""根据MIT许可证编写,Michael Feil 2023。"""
import asyncio
from logging import getLogger
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
# Module-level logger used by the embedding class in this module.
logger = getLogger(__name__)
# Public API exported by ``from ... import *``.
__all__ = ["InfinityEmbeddingsLocal"]
class InfinityEmbeddingsLocal(BaseModel, Embeddings):
    """Optimized Infinity embedding models.

    https://github.com/michaelfeil/infinity

    This class deploys a local Infinity instance to embed text and is meant
    to be used asynchronously. The synchronous ``embed_documents`` /
    ``embed_query`` wrappers run the async versions through ``asyncio.run``
    and therefore cannot be called while an event loop is already running in
    the current thread.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import InfinityEmbeddingsLocal

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
    """

    model: str
    "Underlying model id from Hugging Face, e.g. BAAI/bge-small-en-v1.5."

    revision: Optional[str] = None
    "Model version: the commit hash on Hugging Face."

    batch_size: int = 32
    "Internal batch size used for inference, e.g. 32."

    device: str = "auto"
    "Device to use for inference, e.g. 'cpu', 'cuda' or 'mps'."

    backend: str = "torch"
    "Inference backend, e.g. 'torch' (recommended for ROCm/Nvidia)"
    " or 'optimum' for ONNX/TensorRT."

    model_warmup: bool = True
    "Warm up the model with the maximum batch size."

    engine: Any = None  #: :meta private:
    """Infinity's AsyncEmbeddingEngine, created by ``validate_environment``."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check that ``infinity_emb`` is installed and build the engine.

        This is a post root validator. ``skip_on_failure=True`` keeps it from
        running when a field has already failed validation; otherwise
        ``values`` would lack that key and the lookups below would raise a
        bare ``KeyError`` instead of pydantic's validation error.

        Args:
            values: The validated field values of the model.

        Returns:
            ``values`` with ``"engine"`` set to a freshly constructed
            ``AsyncEmbeddingEngine`` (any caller-supplied ``engine`` is
            replaced).

        Raises:
            ImportError: If the ``infinity_emb`` package is not installed.
        """
        try:
            from infinity_emb import AsyncEmbeddingEngine  # type: ignore
        except ImportError as exc:
            raise ImportError(
                "Please install the "
                "`pip install 'infinity_emb[optimum,torch]>=0.0.24'` "
                "package to use the InfinityEmbeddingsLocal."
            ) from exc
        logger.debug("Using InfinityEmbeddingsLocal with kwargs %s", values)

        values["engine"] = AsyncEmbeddingEngine(
            model_name_or_path=values["model"],
            device=values["device"],
            revision=values["revision"],
            model_warmup=values["model_warmup"],
            batch_size=values["batch_size"],
            # The field is called ``backend`` here but ``engine`` in infinity.
            engine=values["backend"],
        )
        return values

    async def __aenter__(self) -> "InfinityEmbeddingsLocal":
        """Start the engine's background worker and return ``self``.

        Returning ``self`` is what makes the recommended usage work::

            async with InfinityEmbeddingsLocal(
                model="BAAI/bge-small-en-v1.5",
                revision=None,
                device="cpu",
            ) as embedder:
                embeddings = await embedder.aembed_documents(["text1", "text2"])
        """
        await self.engine.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Stop the engine's background worker.

        Required to release references to the PyTorch model. Returns ``None``,
        so exceptions raised inside the ``async with`` body are not suppressed.
        """
        await self.engine.__aexit__(*args)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed a list of texts with the local Infinity engine.

        If the engine is not running (e.g. called outside ``async with``), it
        is started for this single call and stopped again afterwards; a
        warning is logged because this is slow.

        Args:
            texts: The list of texts to embed.

        Returns:
            One embedding per input text. An empty input returns an empty
            list without touching the engine.
        """
        if not texts:
            # Nothing to embed: don't spin the engine up just to do no work.
            return []

        if self.engine.running:
            embeddings, _ = await self.engine.embed(texts)
            return embeddings

        logger.warning(
            "Starting Infinity engine on the fly. This is not recommended. "
            "Please start the engine before using it."
        )
        async with self:
            # spawning threadpool for multithreaded encode, tokenization
            embeddings, _ = await self.engine.embed(texts)
        # the threadpool was stopped when the context manager exited
        logger.warning("Stopped infinity engine after usage.")
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of ``text``.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Synchronously embed a list of texts (prefer ``aembed_documents``).

        This model is async-only; this wrapper runs ``aembed_documents`` via
        ``asyncio.run``, which raises ``RuntimeError`` when called from a
        thread whose event loop is already running.

        Args:
            texts: The list of texts to embed.

        Returns:
            One embedding per input text.
        """
        logger.warning(
            "这个方法只能是异步的。 "
            "Please use the async version `await aembed_documents`."
        )
        return asyncio.run(self.aembed_documents(texts))

    def embed_query(self, text: str) -> List[float]:
        """Synchronously embed a single text (prefer ``aembed_query``).

        This model is async-only; this wrapper runs ``aembed_query`` via
        ``asyncio.run``, which raises ``RuntimeError`` when called from a
        thread whose event loop is already running.

        Args:
            text: The text to embed.

        Returns:
            The embedding of ``text``.
        """
        logger.warning(
            "这个方法只能是异步的。"
            " Please use the async version `await aembed_query`."
        )
        return asyncio.run(self.aembed_query(text))