Source code for langchain_community.embeddings.fastembed

from typing import Any, Dict, List, Literal, Optional

import numpy as np
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator


[docs]class FastEmbedEmbeddings(BaseModel, Embeddings): """Qdrant FastEmbedding 模型。 FastEmbed 是一个轻量级、快速的 Python 库,用于生成嵌入。 更多文档请参见: * https://github.com/qdrant/fastembed/ * https://qdrant.github.io/fastembed/ 要使用这个类,您必须安装 `fastembed` Python 包。 `pip install fastembed` 示例: from langchain_community.embeddings import FastEmbedEmbeddings fastembed = FastEmbedEmbeddings()""" model_name: str = "BAAI/bge-small-en-v1.5" """要使用的FastEmbedding模型的名称 默认为"BAAI/bge-small-en-v1.5" 可在以下网址找到支持的模型列表 https://qdrant.github.io/fastembed/examples/Supported_Models/""" max_length: int = 512 """最大令牌数。默认为512。 对于大于512的值,行为未知。""" cache_dir: Optional[str] """缓存目录的路径。 默认为父目录中的 `local_cache`""" threads: Optional[int] """单个onnxruntime会话可以使用的线程数。 默认为None""" doc_embed_type: Literal["default", "passage"] = "default" """文档使用的嵌入类型 可用选项为:"default" 和 "passage" """ _model: Any # : :meta private: class Config: """此pydantic对象的配置。""" extra = Extra.forbid @root_validator() def validate_environment(cls, values: Dict) -> Dict: """验证FastEmbed是否已安装。""" model_name = values.get("model_name") max_length = values.get("max_length") cache_dir = values.get("cache_dir") threads = values.get("threads") try: # >= v0.2.0 from fastembed import TextEmbedding values["_model"] = TextEmbedding( model_name=model_name, max_length=max_length, cache_dir=cache_dir, threads=threads, ) except ImportError as ie: try: # < v0.2.0 from fastembed.embedding import FlagEmbedding values["_model"] = FlagEmbedding( model_name=model_name, max_length=max_length, cache_dir=cache_dir, threads=threads, ) except ImportError: raise ImportError( "Could not import 'fastembed' Python package. " "Please install it with `pip install fastembed`." ) from ie return values
[docs] def embed_documents(self, texts: List[str]) -> List[List[float]]: """使用FastEmbed为文档生成嵌入。 参数: texts: 要嵌入的文本列表。 返回: 每个文本的嵌入列表。 """ embeddings: List[np.ndarray] if self.doc_embed_type == "passage": embeddings = self._model.passage_embed(texts) else: embeddings = self._model.embed(texts) return [e.tolist() for e in embeddings]
[docs] def embed_query(self, text: str) -> List[float]: """使用FastEmbed生成查询嵌入。 参数: text: 要嵌入的文本。 返回: 文本的嵌入。 """ query_embeddings: np.ndarray = next(self._model.query_embed(text)) return query_embeddings.tolist()