Source code for langchain_community.embeddings.fastembed
from typing import Any, Dict, List, Literal, Optional
import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
[docs]class FastEmbedEmbeddings(BaseModel, Embeddings):
"""Qdrant FastEmbedding 模型。
FastEmbed 是一个轻量级、快速的 Python 库,用于生成嵌入。
更多文档请参见:
* https://github.com/qdrant/fastembed/
* https://qdrant.github.io/fastembed/
要使用这个类,您必须安装 `fastembed` Python 包。
`pip install fastembed`
示例:
from langchain_community.embeddings import FastEmbedEmbeddings
fastembed = FastEmbedEmbeddings()"""
model_name: str = "BAAI/bge-small-en-v1.5"
"""要使用的FastEmbedding模型的名称
默认为"BAAI/bge-small-en-v1.5"
可在以下网址找到支持的模型列表
https://qdrant.github.io/fastembed/examples/Supported_Models/"""
max_length: int = 512
"""最大令牌数。默认为512。
对于大于512的值,行为未知。"""
cache_dir: Optional[str]
"""缓存目录的路径。
默认为父目录中的 `local_cache`"""
threads: Optional[int]
"""单个onnxruntime会话可以使用的线程数。
默认为None"""
doc_embed_type: Literal["default", "passage"] = "default"
"""文档使用的嵌入类型
可用选项为:"default" 和 "passage" """
_model: Any # : :meta private:
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""验证FastEmbed是否已安装。"""
model_name = values.get("model_name")
max_length = values.get("max_length")
cache_dir = values.get("cache_dir")
threads = values.get("threads")
try:
# >= v0.2.0
from fastembed import TextEmbedding
values["_model"] = TextEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
)
except ImportError as ie:
try:
# < v0.2.0
from fastembed.embedding import FlagEmbedding
values["_model"] = FlagEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
)
except ImportError:
raise ImportError(
"Could not import 'fastembed' Python package. "
"Please install it with `pip install fastembed`."
) from ie
return values
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""使用FastEmbed为文档生成嵌入。
参数:
texts: 要嵌入的文本列表。
返回:
每个文本的嵌入列表。
"""
embeddings: List[np.ndarray]
if self.doc_embed_type == "passage":
embeddings = self._model.passage_embed(texts)
else:
embeddings = self._model.embed(texts)
return [e.tolist() for e in embeddings]
[docs] def embed_query(self, text: str) -> List[float]:
"""使用FastEmbed生成查询嵌入。
参数:
text: 要嵌入的文本。
返回:
文本的嵌入。
"""
query_embeddings: np.ndarray = next(self._model.query_embed(text))
return query_embeddings.tolist()