import asyncio
from typing import (
Any,
Iterable,
List,
Optional,
Tuple,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
[docs]class SurrealDBStore(VectorStore):
"""SurrealDB作为向量存储。
要使用,您应该已安装``surrealdb`` python包。
参数:
embedding_function: 要使用的嵌入函数。
dburl: SurrealDB连接url
ns: 用于向量存储的surrealdb命名空间。 (默认值: "langchain")
db: 用于向量存储的surrealdb数据库。 (默认值: "database")
collection: 用于向量存储的surrealdb集合。
(默认值: "documents")
(可选) db_user和db_pass: surrealdb凭据
示例:
.. code-block:: python
from langchain_community.vectorstores.surrealdb import SurrealDBStore
from langchain_community.embeddings import HuggingFaceEmbeddings
embedding_function = HuggingFaceEmbeddings()
dburl = "ws://localhost:8000/rpc"
ns = "langchain"
db = "docstore"
collection = "documents"
db_user = "root"
db_pass = "root"
sdb = SurrealDBStore.from_texts(
texts=texts,
embedding=embedding_function,
dburl,
ns, db, collection,
db_user=db_user, db_pass=db_pass)
"""
[docs] def __init__(
self,
embedding_function: Embeddings,
**kwargs: Any,
) -> None:
try:
from surrealdb import Surreal
except ImportError as e:
raise ImportError(
"""Cannot import from surrealdb.
please install with `pip install surrealdb`."""
) from e
self.dburl = kwargs.pop("dburl", "ws://localhost:8000/rpc")
if self.dburl[0:2] == "ws":
self.sdb = Surreal(self.dburl)
else:
raise ValueError("Only websocket connections are supported at this time.")
self.ns = kwargs.pop("ns", "langchain")
self.db = kwargs.pop("db", "database")
self.collection = kwargs.pop("collection", "documents")
self.embedding_function = embedding_function
self.kwargs = kwargs
[docs] async def initialize(self) -> None:
"""初始化与surrealdb数据库的连接
如果提供了凭据,则进行身份验证
"""
await self.sdb.connect()
if "db_user" in self.kwargs and "db_pass" in self.kwargs:
user = self.kwargs.get("db_user")
password = self.kwargs.get("db_pass")
await self.sdb.signin({"user": user, "pass": password})
await self.sdb.use(self.ns, self.db)
@property
def embeddings(self) -> Optional[Embeddings]:
return (
self.embedding_function
if isinstance(self.embedding_function, Embeddings)
else None
)
[docs] async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""将文本列表与嵌入异步添加到向量存储库中
参数:
texts (Iterable[str]): 要添加到数据库中的文本集合
返回:
新插入文档的ID列表
"""
embeddings = self.embedding_function.embed_documents(list(texts))
ids = []
for idx, text in enumerate(texts):
data = {"text": text, "embedding": embeddings[idx]}
if metadatas is not None and idx < len(metadatas):
data["metadata"] = metadatas[idx] # type: ignore[assignment]
else:
data["metadata"] = []
record = await self.sdb.create(
self.collection,
data,
)
ids.append(record[0]["id"])
return ids
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""将文本列表与嵌入一起添加到向量存储中
参数:
texts (Iterable[str]): 要添加到数据库的文本集合
返回:
新插入文档的id列表
"""
async def _add_texts(
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
await self.initialize()
return await self.aadd_texts(texts, metadatas, **kwargs)
return asyncio.run(_add_texts(texts, metadatas, **kwargs))
[docs] async def adelete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> Optional[bool]:
"""异步按文档ID删除。
参数:
ids:要删除的ID列表。
**kwargs:子类可能使用的其他关键字参数。
返回:
Optional[bool]:如果删除成功,则为True,否则为False。
"""
if ids is None:
await self.sdb.delete(self.collection)
return True
else:
if isinstance(ids, str):
await self.sdb.delete(ids)
return True
else:
if isinstance(ids, list) and len(ids) > 0:
_ = [await self.sdb.delete(id) for id in ids]
return True
return False
[docs] def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> Optional[bool]:
"""根据文档ID删除。
参数:
ids:要删除的ID列表。
**kwargs:子类可能使用的其他关键字参数。
返回:
Optional[bool]:如果删除成功则为True,否则为False。
"""
async def _delete(ids: Optional[List[str]], **kwargs: Any) -> Optional[bool]:
await self.initialize()
return await self.adelete(ids=ids, **kwargs)
return asyncio.run(_delete(ids, **kwargs))
async def _asimilarity_search_by_vector_with_score(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""异步运行相似性搜索以查询嵌入,并返回文档和分数
参数:
embedding(List[float]):查询嵌入。
k(int):要返回的结果数量。默认为4。
返回:
最相似的文档列表以及分数
"""
args = {
"collection": self.collection,
"embedding": embedding,
"k": k,
"score_threshold": kwargs.get("score_threshold", 0),
}
query = f"""
select
id,
text,
metadata,
vector::similarity::cosine(embedding, $embedding) as similarity
from ⟨{args["collection"]}⟩
where vector::similarity::cosine(embedding, $embedding) >= $score_threshold
order by similarity desc LIMIT $k;
"""
results = await self.sdb.query(query, args)
if len(results) == 0:
return []
result = results[0]
if result["status"] != "OK":
from surrealdb.ws import SurrealException
err = result.get("result", "Unknown Error")
raise SurrealException(err)
return [
(
Document(
page_content=doc["text"],
metadata={"id": doc["id"], **(doc.get("metadata", None) or {})},
),
doc["similarity"],
)
for doc in result["result"]
]
[docs] async def asimilarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""异步运行相似性搜索并返回相关性分数
参数:
query (str): 查询
k (int): 要返回的结果数量。默认为4。
返回:
最相似的文档列表以及相关性分数
"""
query_embedding = self.embedding_function.embed_query(query)
return [
(document, similarity)
for document, similarity in (
await self._asimilarity_search_by_vector_with_score(
query_embedding, k, **kwargs
)
)
]
[docs] def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""运行同步相似性搜索并返回相关性分数
参数:
query (str): 查询
k (int): 要返回的结果数量。默认为4。
返回:
最相似的文档列表以及相关性分数
"""
async def _similarity_search_with_relevance_scores() -> (
List[Tuple[Document, float]]
):
await self.initialize()
return await self.asimilarity_search_with_relevance_scores(
query, k, **kwargs
)
return asyncio.run(_similarity_search_with_relevance_scores())
[docs] async def asimilarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""异步运行相似性搜索并返回距离分数
参数:
query (str): 查询
k (int): 要返回的结果数量。默认为4。
返回:
最相似的文档列表以及相关性距离分数
"""
query_embedding = self.embedding_function.embed_query(query)
return [
(document, similarity)
for document, similarity in (
await self._asimilarity_search_by_vector_with_score(
query_embedding, k, **kwargs
)
)
]
[docs] def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""运行同步相似性搜索并返回距离分数
参数:
query (str): 查询
k (int): 要返回的结果数量。默认为4。
返回:
最相似的文档列表以及相关性距离分数
"""
async def _similarity_search_with_score() -> List[Tuple[Document, float]]:
await self.initialize()
return await self.asimilarity_search_with_score(query, k, **kwargs)
return asyncio.run(_similarity_search_with_score())
[docs] async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""在查询嵌入上异步运行相似性搜索
参数:
embedding (List[float]): 查询嵌入
k (int): 要返回的结果数量。默认为4。
返回:
与查询最相似的文档列表
"""
return [
document
for document, _ in await self._asimilarity_search_by_vector_with_score(
embedding, k, **kwargs
)
]
[docs] def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""在查询嵌入上运行相似性搜索
参数:
embedding (List[float]): 查询嵌入
k (int): 要返回的结果数量。默认为4。
返回:
与查询最相似的文档列表
"""
async def _similarity_search_by_vector() -> List[Document]:
await self.initialize()
return await self.asimilarity_search_by_vector(embedding, k, **kwargs)
return asyncio.run(_similarity_search_by_vector())
[docs] async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""在查询上异步运行相似性搜索
参数:
query (str): 查询
k (int): 要返回的结果数量。默认为4。
返回:
与查询最相似的文档列表
"""
query_embedding = self.embedding_function.embed_query(query)
return await self.asimilarity_search_by_vector(query_embedding, k, **kwargs)
[docs] def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""在查询上运行相似性搜索
参数:
query(str):查询
k(int):要返回的结果数量。默认为4。
返回:
与查询最相似的文档列表
"""
async def _similarity_search() -> List[Document]:
await self.initialize()
return await self.asimilarity_search(query, k, **kwargs)
return asyncio.run(_similarity_search())
[docs] @classmethod
async def afrom_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "SurrealDBStore":
"""从文本列表异步创建SurrealDBStore
参数:
texts (List[str]): 要进行向量化和存储的文本列表
embedding (Optional[Embeddings]): 嵌入函数。
dburl (str): SurrealDB连接URL
(默认值: "ws://localhost:8000/rpc")
ns (str): 用于向量存储的SurrealDB命名空间。
(默认值: "langchain")
db (str): 用于向量存储的SurrealDB数据库。
(默认值: "database")
collection (str): 用于向量存储的SurrealDB集合。
(默认值: "documents")
(可选) db_user 和 db_pass: SurrealDB凭据
返回:
初始化并准备就绪的SurrealDBStore对象。
"""
sdb = cls(embedding, **kwargs)
await sdb.initialize()
await sdb.aadd_texts(texts, metadatas, **kwargs)
return sdb
[docs] @classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "SurrealDBStore":
"""从文本列表创建SurrealDBStore
参数:
texts (List[str]): 要进行向量化和存储的文本列表
embedding (Optional[Embeddings]): 嵌入函数。
dburl (str): SurrealDB连接URL
ns (str): 用于向量存储的SurrealDB命名空间。
(默认值: "langchain")
db (str): 用于向量存储的SurrealDB数据库。
(默认值: "database")
collection (str): 用于向量存储的SurrealDB集合。
(默认值: "documents")
(可选) db_user 和 db_pass: SurrealDB凭据
返回:
初始化并准备就绪的SurrealDBStore对象。
"""
sdb = asyncio.run(cls.afrom_texts(texts, embedding, metadatas, **kwargs))
return sdb