Source code for langchain_community.vectorstores.epsilla

"""封装了对Epsilla向量数据库的操作。"""
from __future__ import annotations

import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

if TYPE_CHECKING:
    from pyepsilla import vectordb

logger = logging.getLogger()


[docs]class Epsilla(VectorStore): """封装了Epsilla向量数据库。 作为先决条件,您需要安装``pyepsilla``包 并且有一个运行中的Epsilla向量数据库(例如,通过我们的docker镜像) 请参阅以下文档,了解如何运行Epsilla向量数据库: https://epsilla-inc.gitbook.io/epsilladb/quick-start 参数: client (Any): 用于连接的Epsilla客户端。 embeddings (Embeddings): 用于嵌入文本的函数。 db_path (Optional[str]): 数据库将被持久化的路径。 默认为"/tmp/langchain-epsilla"。 db_name (Optional[str]): 给加载的数据库命名。 默认为"langchain_store"。 示例: .. code-block:: python from langchain_community.vectorstores import Epsilla from pyepsilla import vectordb client = vectordb.Client() embeddings = OpenAIEmbeddings() db_path = "/tmp/vectorstore" db_name = "langchain_store" epsilla = Epsilla(client, embeddings, db_path, db_name) """ _LANGCHAIN_DEFAULT_DB_NAME = "langchain_store" _LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla" _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
[docs] def __init__( self, client: Any, embeddings: Embeddings, db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH, db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME, ): """使用必要的组件进行初始化。""" try: import pyepsilla except ImportError as e: raise ImportError( "Could not import pyepsilla python package. " "Please install pyepsilla package with `pip install pyepsilla`." ) from e if not isinstance(client, pyepsilla.vectordb.Client): raise TypeError( f"client should be an instance of pyepsilla.vectordb.Client, " f"got {type(client)}" ) self._client: vectordb.Client = client self._db_name = db_name self._embeddings = embeddings self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME self._client.load_db(db_name=db_name, db_path=db_path) self._client.use_db(db_name=db_name)
@property def embeddings(self) -> Optional[Embeddings]: return self._embeddings
[docs] def use_collection(self, collection_name: str) -> None: """设置默认使用的集合。 参数: collection_name (str): 集合的名称。 """ self._collection_name = collection_name
[docs] def clear_data(self, collection_name: str = "") -> None: """清除集合中的数据。 参数: collection_name(可选[str]):集合的名称。 如果未提供,则将使用默认集合。 """ if not collection_name: collection_name = self._collection_name self._client.drop_table(collection_name)
[docs] def get( self, collection_name: str = "", response_fields: Optional[List[str]] = None ) -> List[dict]: """获取集合。 参数: collection_name(可选[str]):要从中检索数据的集合名称。 如果未提供,则将使用默认集合。 response_fields(可选[List[str]):结果中字段名称的列表。 如果未指定,将响应所有可用字段。 返回: 检索到的数据列表。 """ if not collection_name: collection_name = self._collection_name status_code, response = self._client.get( table_name=collection_name, response_fields=response_fields ) if status_code != 200: logger.error(f"Failed to get records: {response['message']}") raise Exception("Error: {}.".format(response["message"])) return response["result"]
def _create_collection( self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None ) -> None: if not embeddings: raise ValueError("Embeddings list is empty.") dim = len(embeddings[0]) fields: List[dict] = [ {"name": "id", "dataType": "INT"}, {"name": "text", "dataType": "STRING"}, {"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim}, ] if metadatas is not None: field_names = [field["name"] for field in fields] for metadata in metadatas: for key, value in metadata.items(): if key in field_names: continue d_type: str if isinstance(value, str): d_type = "STRING" elif isinstance(value, int): d_type = "INT" elif isinstance(value, float): d_type = "FLOAT" elif isinstance(value, bool): d_type = "BOOL" else: raise ValueError(f"Unsupported data type for {key}.") fields.append({"name": key, "dataType": d_type}) field_names.append(key) status_code, response = self._client.create_table( table_name, table_fields=fields ) if status_code != 200: if status_code == 409: logger.info(f"Continuing with the existing table {table_name}.") else: logger.error( f"Failed to create collection {table_name}: {response['message']}" ) raise Exception("Error: {}.".format(response["message"]))
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, collection_name: Optional[str] = "", drop_old: Optional[bool] = False, **kwargs: Any, ) -> List[str]: """将文本嵌入并将其添加到数据库中。 参数: texts(Iterable[str]):要嵌入的文本。 metadatas(Optional[List[dict]]):附加到每个文本的元数据字典。默认为None。 collection_name(Optional[str]):要使用的集合名称。默认为“langchain_collection”。 如果提供,将设置默认集合名称。 drop_old(Optional[bool]):是否删除先前的集合并创建新集合。默认为False。 返回: 添加的文本的id列表。 """ if not collection_name: collection_name = self._collection_name else: self._collection_name = collection_name if drop_old: self._client.drop_db(db_name=collection_name) texts = list(texts) try: embeddings = self._embeddings.embed_documents(texts) except NotImplementedError: embeddings = [self._embeddings.embed_query(x) for x in texts] if len(embeddings) == 0: logger.debug("Nothing to insert, skipping.") return [] self._create_collection( table_name=collection_name, embeddings=embeddings, metadatas=metadatas ) ids = [hash(uuid.uuid4()) for _ in texts] records = [] for index, id in enumerate(ids): record = { "id": id, "text": texts[index], "embeddings": embeddings[index], } if metadatas is not None: metadata = metadatas[index].items() for key, value in metadata: record[key] = value records.append(record) status_code, response = self._client.insert( table_name=collection_name, records=records ) if status_code != 200: logger.error( f"Failed to add records to {collection_name}: {response['message']}" ) raise Exception("Error: {}.".format(response["message"])) return [str(id) for id in ids]
[docs] @classmethod def from_texts( cls: Type[Epsilla], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, client: Any = None, db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH, db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME, collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME, drop_old: Optional[bool] = False, **kwargs: Any, ) -> Epsilla: """从原始文档创建一个Epsilla向量存储。 参数: texts(List[str]):要插入的文本数据列表。 embeddings(Embeddings):嵌入函数。 client(pyepsilla.vectordb.Client):用于连接的Epsilla客户端。 metadatas(Optional[List[dict]]):每个文本的元数据。 默认为None。 db_path(Optional[str]):数据库将持久化的路径。 默认为"/tmp/langchain-epsilla"。 db_name(Optional[str]):为加载的数据库命名。 默认为"langchain_store"。 collection_name(Optional[str]):要使用的集合。 默认为"langchain_collection"。 如果提供,还将设置默认集合名称。 drop_old(Optional[bool]):是否删除先前的集合 并创建一个新的。默认为False。 返回: Epsilla:Epsilla向量存储。 """ instance = Epsilla(client, embedding, db_path=db_path, db_name=db_name) instance.add_texts( texts, metadatas=metadatas, collection_name=collection_name, drop_old=drop_old, **kwargs, ) return instance
[docs] @classmethod def from_documents( cls: Type[Epsilla], documents: List[Document], embedding: Embeddings, client: Any = None, db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH, db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME, collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME, drop_old: Optional[bool] = False, **kwargs: Any, ) -> Epsilla: """从文档列表创建一个Epsilla向量存储。 参数: texts (List[str]): 要插入的文本数据列表。 embeddings (Embeddings): 嵌入函数。 client (pyepsilla.vectordb.Client): 用于连接的Epsilla客户端。 metadatas (Optional[List[dict]]): 每个文本的元数据。 默认为None。 db_path (Optional[str]): 数据库将持久化的路径。 默认为"/tmp/langchain-epsilla"。 db_name (Optional[str]): 给加载的数据库命名。 默认为"langchain_store"。 collection_name (Optional[str]): 要使用的集合。 默认为"langchain_collection"。 如果提供,将设置默认集合名称。 drop_old (Optional[bool]): 是否删除先前的集合并创建新集合。 默认为False。 返回: Epsilla: Epsilla向量存储。 """ texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] return cls.from_texts( texts, embedding, metadatas=metadatas, client=client, db_path=db_path, db_name=db_name, collection_name=collection_name, drop_old=drop_old, **kwargs, )