"""封装了对Epsilla向量数据库的操作。"""
from __future__ import annotations
import logging
import uuid
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from pyepsilla import vectordb
logger = logging.getLogger()
[docs]class Epsilla(VectorStore):
"""封装了Epsilla向量数据库。
作为先决条件,您需要安装``pyepsilla``包
并且有一个运行中的Epsilla向量数据库(例如,通过我们的docker镜像)
请参阅以下文档,了解如何运行Epsilla向量数据库:
https://epsilla-inc.gitbook.io/epsilladb/quick-start
参数:
client (Any): 用于连接的Epsilla客户端。
embeddings (Embeddings): 用于嵌入文本的函数。
db_path (Optional[str]): 数据库将被持久化的路径。
默认为"/tmp/langchain-epsilla"。
db_name (Optional[str]): 给加载的数据库命名。
默认为"langchain_store"。
示例:
.. code-block:: python
from langchain_community.vectorstores import Epsilla
from pyepsilla import vectordb
client = vectordb.Client()
embeddings = OpenAIEmbeddings()
db_path = "/tmp/vectorstore"
db_name = "langchain_store"
epsilla = Epsilla(client, embeddings, db_path, db_name)
"""
_LANGCHAIN_DEFAULT_DB_NAME = "langchain_store"
_LANGCHAIN_DEFAULT_DB_PATH = "/tmp/langchain-epsilla"
_LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_collection"
[docs] def __init__(
self,
client: Any,
embeddings: Embeddings,
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
):
"""使用必要的组件进行初始化。"""
try:
import pyepsilla
except ImportError as e:
raise ImportError(
"Could not import pyepsilla python package. "
"Please install pyepsilla package with `pip install pyepsilla`."
) from e
if not isinstance(client, pyepsilla.vectordb.Client):
raise TypeError(
f"client should be an instance of pyepsilla.vectordb.Client, "
f"got {type(client)}"
)
self._client: vectordb.Client = client
self._db_name = db_name
self._embeddings = embeddings
self._collection_name = Epsilla._LANGCHAIN_DEFAULT_TABLE_NAME
self._client.load_db(db_name=db_name, db_path=db_path)
self._client.use_db(db_name=db_name)
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
[docs] def use_collection(self, collection_name: str) -> None:
"""设置默认使用的集合。
参数:
collection_name (str): 集合的名称。
"""
self._collection_name = collection_name
[docs] def clear_data(self, collection_name: str = "") -> None:
"""清除集合中的数据。
参数:
collection_name(可选[str]):集合的名称。
如果未提供,则将使用默认集合。
"""
if not collection_name:
collection_name = self._collection_name
self._client.drop_table(collection_name)
[docs] def get(
self, collection_name: str = "", response_fields: Optional[List[str]] = None
) -> List[dict]:
"""获取集合。
参数:
collection_name(可选[str]):要从中检索数据的集合名称。
如果未提供,则将使用默认集合。
response_fields(可选[List[str]):结果中字段名称的列表。
如果未指定,将响应所有可用字段。
返回:
检索到的数据列表。
"""
if not collection_name:
collection_name = self._collection_name
status_code, response = self._client.get(
table_name=collection_name, response_fields=response_fields
)
if status_code != 200:
logger.error(f"Failed to get records: {response['message']}")
raise Exception("Error: {}.".format(response["message"]))
return response["result"]
def _create_collection(
self, table_name: str, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
if not embeddings:
raise ValueError("Embeddings list is empty.")
dim = len(embeddings[0])
fields: List[dict] = [
{"name": "id", "dataType": "INT"},
{"name": "text", "dataType": "STRING"},
{"name": "embeddings", "dataType": "VECTOR_FLOAT", "dimensions": dim},
]
if metadatas is not None:
field_names = [field["name"] for field in fields]
for metadata in metadatas:
for key, value in metadata.items():
if key in field_names:
continue
d_type: str
if isinstance(value, str):
d_type = "STRING"
elif isinstance(value, int):
d_type = "INT"
elif isinstance(value, float):
d_type = "FLOAT"
elif isinstance(value, bool):
d_type = "BOOL"
else:
raise ValueError(f"Unsupported data type for {key}.")
fields.append({"name": key, "dataType": d_type})
field_names.append(key)
status_code, response = self._client.create_table(
table_name, table_fields=fields
)
if status_code != 200:
if status_code == 409:
logger.info(f"Continuing with the existing table {table_name}.")
else:
logger.error(
f"Failed to create collection {table_name}: {response['message']}"
)
raise Exception("Error: {}.".format(response["message"]))
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
collection_name: Optional[str] = "",
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> List[str]:
"""将文本嵌入并将其添加到数据库中。
参数:
texts(Iterable[str]):要嵌入的文本。
metadatas(Optional[List[dict]]):附加到每个文本的元数据字典。默认为None。
collection_name(Optional[str]):要使用的集合名称。默认为“langchain_collection”。
如果提供,将设置默认集合名称。
drop_old(Optional[bool]):是否删除先前的集合并创建新集合。默认为False。
返回:
添加的文本的id列表。
"""
if not collection_name:
collection_name = self._collection_name
else:
self._collection_name = collection_name
if drop_old:
self._client.drop_db(db_name=collection_name)
texts = list(texts)
try:
embeddings = self._embeddings.embed_documents(texts)
except NotImplementedError:
embeddings = [self._embeddings.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
self._create_collection(
table_name=collection_name, embeddings=embeddings, metadatas=metadatas
)
ids = [hash(uuid.uuid4()) for _ in texts]
records = []
for index, id in enumerate(ids):
record = {
"id": id,
"text": texts[index],
"embeddings": embeddings[index],
}
if metadatas is not None:
metadata = metadatas[index].items()
for key, value in metadata:
record[key] = value
records.append(record)
status_code, response = self._client.insert(
table_name=collection_name, records=records
)
if status_code != 200:
logger.error(
f"Failed to add records to {collection_name}: {response['message']}"
)
raise Exception("Error: {}.".format(response["message"]))
return [str(id) for id in ids]
[docs] def similarity_search(
self, query: str, k: int = 4, collection_name: str = "", **kwargs: Any
) -> List[Document]:
"""返回与查询语义最相关的文档。
参数:
query(str):用于查询向量存储的字符串。
k(可选[int]):要返回的文档数量。默认为4。
collection_name(可选[str]):要使用的集合。默认为“langchain_store”或之前提供的集合。
返回:
与查询语义最相关的文档列表。
"""
if not collection_name:
collection_name = self._collection_name
query_vector = self._embeddings.embed_query(query)
status_code, response = self._client.query(
table_name=collection_name,
query_field="embeddings",
query_vector=query_vector,
limit=k,
)
if status_code != 200:
logger.error(f"Search failed: {response['message']}.")
raise Exception("Error: {}.".format(response["message"]))
exclude_keys = ["id", "text", "embeddings"]
return list(
map(
lambda item: Document(
page_content=item["text"],
metadata={
key: item[key] for key in item if key not in exclude_keys
},
),
response["result"],
)
)
[docs] @classmethod
def from_texts(
cls: Type[Epsilla],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
client: Any = None,
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> Epsilla:
"""从原始文档创建一个Epsilla向量存储。
参数:
texts(List[str]):要插入的文本数据列表。
embeddings(Embeddings):嵌入函数。
client(pyepsilla.vectordb.Client):用于连接的Epsilla客户端。
metadatas(Optional[List[dict]]):每个文本的元数据。
默认为None。
db_path(Optional[str]):数据库将持久化的路径。
默认为"/tmp/langchain-epsilla"。
db_name(Optional[str]):为加载的数据库命名。
默认为"langchain_store"。
collection_name(Optional[str]):要使用的集合。
默认为"langchain_collection"。
如果提供,还将设置默认集合名称。
drop_old(Optional[bool]):是否删除先前的集合
并创建一个新的。默认为False。
返回:
Epsilla:Epsilla向量存储。
"""
instance = Epsilla(client, embedding, db_path=db_path, db_name=db_name)
instance.add_texts(
texts,
metadatas=metadatas,
collection_name=collection_name,
drop_old=drop_old,
**kwargs,
)
return instance
[docs] @classmethod
def from_documents(
cls: Type[Epsilla],
documents: List[Document],
embedding: Embeddings,
client: Any = None,
db_path: Optional[str] = _LANGCHAIN_DEFAULT_DB_PATH,
db_name: Optional[str] = _LANGCHAIN_DEFAULT_DB_NAME,
collection_name: Optional[str] = _LANGCHAIN_DEFAULT_TABLE_NAME,
drop_old: Optional[bool] = False,
**kwargs: Any,
) -> Epsilla:
"""从文档列表创建一个Epsilla向量存储。
参数:
texts (List[str]): 要插入的文本数据列表。
embeddings (Embeddings): 嵌入函数。
client (pyepsilla.vectordb.Client): 用于连接的Epsilla客户端。
metadatas (Optional[List[dict]]): 每个文本的元数据。
默认为None。
db_path (Optional[str]): 数据库将持久化的路径。
默认为"/tmp/langchain-epsilla"。
db_name (Optional[str]): 给加载的数据库命名。
默认为"langchain_store"。
collection_name (Optional[str]): 要使用的集合。
默认为"langchain_collection"。
如果提供,将设置默认集合名称。
drop_old (Optional[bool]): 是否删除先前的集合并创建新集合。
默认为False。
返回:
Epsilla: Epsilla向量存储。
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
texts,
embedding,
metadatas=metadatas,
client=client,
db_path=db_path,
db_name=db_name,
collection_name=collection_name,
drop_old=drop_old,
**kwargs,
)