from __future__ import annotations
import logging
import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
import awadb
logger = logging.getLogger()
DEFAULT_TOPN = 4
[docs]class AwaDB(VectorStore):
"""`AwaDB` 向量存储。"""
_DEFAULT_TABLE_NAME = "langchain_awadb"
[docs] def __init__(
self,
table_name: str = _DEFAULT_TABLE_NAME,
embedding: Optional[Embeddings] = None,
log_and_data_dir: Optional[str] = None,
client: Optional[awadb.Client] = None,
**kwargs: Any,
) -> None:
"""使用AwaDB客户端进行初始化。
如果未指定table_name,
将自动生成一个随机的table_name,格式为`_DEFAULT_TABLE_NAME + uuid的最后一段`。
参数:
table_name: 创建的表的名称,默认为_DEFAULT_TABLE_NAME。
embedding: 可选的初始嵌入。
log_and_data_dir: 可选的日志和数据的根目录。
client: 可选的AwaDB客户端。
kwargs: 未来可能扩展的任何参数。
返回:
无。
"""
try:
import awadb
except ImportError:
raise ImportError(
"Could not import awadb python package. "
"Please install it with `pip install awadb`."
)
if client is not None:
self.awadb_client = client
else:
if log_and_data_dir is not None:
self.awadb_client = awadb.Client(log_and_data_dir)
else:
self.awadb_client = awadb.Client()
if table_name == self._DEFAULT_TABLE_NAME:
table_name += "_"
table_name += str(uuid.uuid4()).split("-")[-1]
self.awadb_client.Create(table_name)
self.table2embeddings: dict[str, Embeddings] = {}
if embedding is not None:
self.table2embeddings[table_name] = embedding
self.using_table_name = table_name
@property
def embeddings(self) -> Optional[Embeddings]:
if self.using_table_name in self.table2embeddings:
return self.table2embeddings[self.using_table_name]
return None
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
is_duplicate_texts: Optional[bool] = None,
**kwargs: Any,
) -> List[str]:
"""运行更多文本通过嵌入并添加到向量存储中。
参数:
texts:要添加到向量存储中的字符串的可迭代对象。
metadatas:可选的与文本相关联的元数据列表。
is_duplicate_texts:可选是否复制文本。默认为True。
kwargs:将来可能扩展的任何可能的参数。
返回:
将文本添加到向量存储中的ID列表。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
embeddings = None
if self.using_table_name in self.table2embeddings:
embeddings = self.table2embeddings[self.using_table_name].embed_documents(
list(texts)
)
return self.awadb_client.AddTexts(
"embedding_text",
"text_embedding",
texts,
embeddings,
metadatas,
is_duplicate_texts,
)
[docs] def load_local(
self,
table_name: str,
**kwargs: Any,
) -> bool:
"""加载本地指定的表格。
参数:
table_name: 表格名称
kwargs: 未来可能扩展的任何参数。
返回:
加载本地指定表格的成功或失败。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
return self.awadb_client.Load(table_name)
[docs] def similarity_search(
self,
query: str,
k: int = DEFAULT_TOPN,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query: 文本查询。
k: 要返回的文档的最大数量。
text_in_page_content: 按文档的page_content中的文本进行过滤。
meta_filter (Optional[dict]): 按元数据进行过滤。默认为None。
例如 `{"color" : "red", "price": 4.20}`。可选。
例如 `{"max_price" : 15.66, "min_price": 4.20}`
`price`是元数据字段,表示范围过滤(4.20<'price'<15.66)。
例如 `{"maxe_price" : 15.66, "mine_price": 4.20}`
`price`是元数据字段,表示范围过滤(4.20<='price'<=15.66)。
kwargs: 未来可能的任何扩展参数。
返回:
返回与指定文本查询最相似的k个文档。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
embedding = None
if self.using_table_name in self.table2embeddings:
embedding = self.table2embeddings[self.using_table_name].embed_query(query)
else:
from awadb import AwaEmbedding
embedding = AwaEmbedding().Embedding(query)
not_include_fields: Set[str] = {"text_embedding", "_id", "score"}
return self.similarity_search_by_vector(
embedding,
k,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
not_include_fields_in_metadata=not_include_fields,
)
[docs] def similarity_search_with_score(
self,
query: str,
k: int = DEFAULT_TOPN,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""指定查询的最相似的 k 个文档及其得分。
参数:
query: 文本查询。
k: 与文本查询最相似的前 k 个文档。
text_in_page_content: 根据文档的 page_content 中的文本进行过滤。
meta_filter: 根据元数据进行过滤。默认为 None。
kwargs: 未来可能扩展的任何参数。
返回:
指定文本查询的前 k 个最相似的文档。
0 表示不相似,1 表示最相似。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
embedding = None
if self.using_table_name in self.table2embeddings:
embedding = self.table2embeddings[self.using_table_name].embed_query(query)
else:
from awadb import AwaEmbedding
embedding = AwaEmbedding().Embedding(query)
results: List[Tuple[Document, float]] = []
not_include_fields: Set[str] = {"text_embedding", "_id"}
retrieval_docs = self.similarity_search_by_vector(
embedding,
k,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
not_include_fields_in_metadata=not_include_fields,
)
for doc in retrieval_docs:
score = doc.metadata["score"]
del doc.metadata["score"]
doc_tuple = (doc, score)
results.append(doc_tuple)
return results
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
return self.similarity_search_with_score(query, k, **kwargs)
[docs] def similarity_search_by_vector(
self,
embedding: Optional[List[float]] = None,
k: int = DEFAULT_TOPN,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
not_include_fields_in_metadata: Optional[Set[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
text_in_page_content: 按文档的page_content中的文本进行过滤。
meta_filter: 按元数据进行过滤。默认为None。
not_incude_fields_in_metadata: 不包括每个文档的元字段。
返回:
与查询向量最相似的文档列表。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
results: List[Document] = []
if embedding is None:
return results
show_results = self.awadb_client.Search(
embedding,
k,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
not_include_fields=not_include_fields_in_metadata,
)
if show_results.__len__() == 0:
return results
for item_detail in show_results[0]["ResultItems"]:
content = ""
meta_data = {}
for item_key in item_detail:
if item_key == "embedding_text":
content = item_detail[item_key]
continue
elif not_include_fields_in_metadata is not None:
if item_key in not_include_fields_in_metadata:
continue
meta_data[item_key] = item_detail[item_key]
results.append(Document(page_content=content, metadata=meta_data))
return results
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找相似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,0对应最大多样性,1对应最小多样性。默认为0.5。
text_in_page_content:按文档页面内容中的文本进行过滤。
meta_filter(可选[dict]):按元数据进行过滤。默认为None。
返回:
通过最大边际相关性选择的文档列表。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
embedding: List[float] = []
if self.using_table_name in self.table2embeddings:
embedding = self.table2embeddings[self.using_table_name].embed_query(query)
else:
from awadb import AwaEmbedding
embedding = AwaEmbedding().Embedding(query)
if embedding.__len__() == 0:
return []
results = self.max_marginal_relevance_search_by_vector(
embedding,
k,
fetch_k,
lambda_mult=lambda_mult,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
)
return results
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:介于0和1之间的数字,确定结果之间多样性的程度,0对应最大多样性,1对应最小多样性。默认为0.5。
text_in_page_content:按文档的page_content中的文本进行过滤。
meta_filter(可选[dict]):按元数据进行过滤。默认为None。
返回:
通过最大边际相关性选择的文档列表。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
results: List[Document] = []
if embedding is None:
return results
not_include_fields: set = {"_id", "score"}
retrieved_docs = self.similarity_search_by_vector(
embedding,
fetch_k,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
not_include_fields_in_metadata=not_include_fields,
)
top_embeddings = []
for doc in retrieved_docs:
top_embeddings.append(doc.metadata["text_embedding"])
selected_docs = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32), embedding_list=top_embeddings
)
for s_id in selected_docs:
if "text_embedding" in retrieved_docs[s_id].metadata:
del retrieved_docs[s_id].metadata["text_embedding"]
results.append(retrieved_docs[s_id])
return results
[docs] def get(
self,
ids: Optional[List[str]] = None,
text_in_page_content: Optional[str] = None,
meta_filter: Optional[dict] = None,
not_include_fields: Optional[Set[str]] = None,
limit: Optional[int] = None,
**kwargs: Any,
) -> Dict[str, Document]:
"""根据ids返回文档。
参数:
ids:嵌入向量的ids。
text_in_page_content:按文档的page_content中的文本进行过滤。
meta_filter:按文档的任何元数据进行过滤。
not_include_fields:不打包每个文档的指定字段。
limit:要返回的文档数量。默认为5。可选。
返回:
满足输入条件的文档。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
docs_detail = self.awadb_client.Get(
ids=ids,
text_in_page_content=text_in_page_content,
meta_filter=meta_filter,
not_include_fields=not_include_fields,
limit=limit,
)
results: Dict[str, Document] = {}
for doc_detail in docs_detail:
content = ""
meta_info = {}
for field in doc_detail:
if field == "embedding_text":
content = doc_detail[field]
continue
elif field == "text_embedding" or field == "_id":
continue
meta_info[field] = doc_detail[field]
doc = Document(page_content=content, metadata=meta_info)
results[doc_detail["_id"]] = doc
return results
[docs] def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> Optional[bool]:
"""删除具有指定id的文档。
参数:
ids:嵌入向量的id。
**kwargs:子类可能使用的其他关键字参数。
返回:
Optional[bool]:如果删除成功则为True。
否则为False,如果未实现则为None。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
ret: Optional[bool] = None
if ids is None or ids.__len__() == 0:
return ret
ret = self.awadb_client.Delete(ids)
return ret
[docs] def update(
self,
ids: List[str],
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""更新具有指定ID的文档。
参数:
ids:要更新嵌入向量的ID列表。
texts:要更新文档的文本。
metadatas:要更新文档的元数据。
返回:
更新文档的ID。
"""
if self.awadb_client is None:
raise ValueError("AwaDB client is None!!!")
return self.awadb_client.UpdateTexts(
ids=ids, text_field_name="embedding_text", texts=texts, metadatas=metadatas
)
[docs] def create_table(
self,
table_name: str,
**kwargs: Any,
) -> bool:
"""创建一个新表。"""
if self.awadb_client is None:
return False
ret = self.awadb_client.Create(table_name)
if ret:
self.using_table_name = table_name
return ret
[docs] def use(
self,
table_name: str,
**kwargs: Any,
) -> bool:
"""使用指定的表。如果不知道表,请调用list_tables。"""
if self.awadb_client is None:
return False
ret = self.awadb_client.Use(table_name)
if ret:
self.using_table_name = table_name
return ret
[docs] def list_tables(
self,
**kwargs: Any,
) -> List[str]:
"""列出客户端创建的所有表格。"""
if self.awadb_client is None:
return []
return self.awadb_client.ListAllTables()
[docs] def get_current_table(
self,
**kwargs: Any,
) -> str:
"""获取当前表格。"""
return self.using_table_name
[docs] @classmethod
def from_texts(
cls: Type[AwaDB],
texts: List[str],
embedding: Optional[Embeddings] = None,
metadatas: Optional[List[dict]] = None,
table_name: str = _DEFAULT_TABLE_NAME,
log_and_data_dir: Optional[str] = None,
client: Optional[awadb.Client] = None,
**kwargs: Any,
) -> AwaDB:
"""从原始文档创建一个AwaDB向量存储。
参数:
texts (List[str]): 要添加到表中的文本列表。
embedding (Optional[Embeddings]): 嵌入函数。默认为None。
metadatas (Optional[List[dict]]): 元数据列表。默认为None。
table_name (str): 要创建的表的名称。
log_and_data_dir (Optional[str]): 日志和持久化目录。
client (Optional[awadb.Client]): AwaDB客户端
返回:
AwaDB: AwaDB向量存储。
"""
awadb_client = cls(
table_name=table_name,
embedding=embedding,
log_and_data_dir=log_and_data_dir,
client=client,
)
awadb_client.add_texts(texts=texts, metadatas=metadatas)
return awadb_client
[docs] @classmethod
def from_documents(
cls: Type[AwaDB],
documents: List[Document],
embedding: Optional[Embeddings] = None,
table_name: str = _DEFAULT_TABLE_NAME,
log_and_data_dir: Optional[str] = None,
client: Optional[awadb.Client] = None,
**kwargs: Any,
) -> AwaDB:
"""从文档列表创建一个AwaDB向量存储。
如果指定了log_and_data_dir,则表将持久化在那里。
参数:
documents (List[Document]): 要添加到向量存储的文档列表。
embedding (Optional[Embeddings]): 嵌入函数。默认为None。
table_name (str): 要创建的表的名称。
log_and_data_dir (Optional[str]): 持久化表的目录。
client (Optional[awadb.Client]): AwaDB客户端。
Any: 未来可能出现的任何参数
返回:
AwaDB: AwaDB向量存储。
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
texts=texts,
embedding=embedding,
metadatas=metadatas,
table_name=table_name,
log_and_data_dir=log_and_data_dir,
client=client,
)