"""封装了TileDB向量数据库。"""
from __future__ import annotations
import pickle
import random
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
INDEX_METRICS = frozenset(["euclidean"])
DEFAULT_METRIC = "euclidean"
DOCUMENTS_ARRAY_NAME = "documents"
VECTOR_INDEX_NAME = "vectors"
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
MAX_FLOAT = sys.float_info.max
[docs]def dependable_tiledb_import() -> Any:
"""如果可用,导入tiledb-vector-search,否则引发错误。"""
return (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
[docs]def get_vector_index_uri_from_group(group: Any) -> str:
"""获取向量索引的URI。"""
return group[VECTOR_INDEX_NAME].uri
[docs]def get_documents_array_uri_from_group(group: Any) -> str:
"""从group中获取文档数组的URI。
参数:
group: TileDB group对象。
返回:
文档数组的URI。
"""
return group[DOCUMENTS_ARRAY_NAME].uri
[docs]def get_vector_index_uri(uri: str) -> str:
"""获取向量索引的URI。"""
return f"{uri}/{VECTOR_INDEX_NAME}"
[docs]def get_documents_array_uri(uri: str) -> str:
"""获取文档数组的URI。"""
return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
[docs]class TileDB(VectorStore):
"""TileDB向量存储。
要使用,您应该已安装``tiledb-vector-search`` python包。
示例:
.. code-block:: python
from langchain_community import TileDB
embeddings = OpenAIEmbeddings()
db = TileDB(embeddings, index_uri, metric)"""
[docs] def __init__(
self,
embedding: Embeddings,
index_uri: str,
metric: str,
*,
vector_index_uri: str = "",
docs_array_uri: str = "",
config: Optional[Mapping[str, Any]] = None,
timestamp: Any = None,
allow_dangerous_deserialization: bool = False,
**kwargs: Any,
):
"""初始化必要组件。
参数:
allow_dangerous_deserialization: 是否允许反序列化数据,涉及使用pickle加载数据。
数据可能被恶意用户修改,以传递恶意有效负载,导致在您的计算机上执行任意代码。
"""
if not allow_dangerous_deserialization:
raise ValueError(
"TileDB relies on pickle for serialization and deserialization. "
"This can be dangerous if the data is intercepted and/or modified "
"by malicious actors prior to being de-serialized. "
"If you are sure that the data is safe from modification, you can "
" set allow_dangerous_deserialization=True to proceed. "
"Loading of compromised data using pickle can result in execution of "
"arbitrary code on your machine."
)
self.embedding = embedding
self.embedding_function = embedding.embed_query
self.index_uri = index_uri
self.metric = metric
self.config = config
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
with tiledb.scope_ctx(ctx_or_config=config):
index_group = tiledb.Group(self.index_uri, "r")
self.vector_index_uri = (
vector_index_uri
if vector_index_uri != ""
else get_vector_index_uri_from_group(index_group)
)
self.docs_array_uri = (
docs_array_uri
if docs_array_uri != ""
else get_documents_array_uri_from_group(index_group)
)
index_group.close()
group = tiledb.Group(self.vector_index_uri, "r")
self.index_type = group.meta.get("index_type")
group.close()
self.timestamp = timestamp
if self.index_type == "FLAT":
self.vector_index = tiledb_vs.flat_index.FlatIndex(
uri=self.vector_index_uri,
config=self.config,
timestamp=self.timestamp,
**kwargs,
)
elif self.index_type == "IVF_FLAT":
self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
uri=self.vector_index_uri,
config=self.config,
timestamp=self.timestamp,
**kwargs,
)
@property
def embeddings(self) -> Optional[Embeddings]:
return self.embedding
[docs] def process_index_results(
self,
ids: List[int],
scores: List[float],
*,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
score_threshold: float = MAX_FLOAT,
) -> List[Tuple[Document, float]]:
"""将TileDB的结果转换为文档列表和分数列表。
参数:
ids:文档在索引中的索引列表。
scores:文档在索引中的距离列表。
k:要返回的文档数量。默认为4。
filter(可选[Dict[str, Any]]):按元数据筛选。默认为None。
score_threshold:可选,一个浮点值,用于过滤检索到的文档集。
返回:
文档和分数的列表。
"""
tiledb = guard_import("tiledb")
docs = []
docs_array = tiledb.open(
self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
)
for idx, score in zip(ids, scores):
if idx == 0 and score == 0:
continue
if idx == MAX_UINT64 and score == MAX_FLOAT_32:
continue
doc = docs_array[idx]
if doc is None or len(doc["text"]) == 0:
raise ValueError(f"Could not find document for id {idx}, got {doc}")
pickled_metadata = doc.get("metadata")
result_doc = Document(page_content=str(doc["text"][0]))
if pickled_metadata is not None:
metadata = pickle.loads(
np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
)
result_doc.metadata = metadata
if filter is not None:
filter = {
key: [value] if not isinstance(value, list) else value
for key, value in filter.items()
}
if all(
result_doc.metadata.get(key) in value
for key, value in filter.items()
):
docs.append((result_doc, score))
else:
docs.append((result_doc, score))
docs_array.close()
docs = [(doc, score) for doc, score in docs if score <= score_threshold]
return docs[:k]
[docs] def similarity_search_with_score_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档。
参数:
embedding: 要查找相似文档的嵌入向量。
k: 要返回的文档数量。默认为4。
filter (Optional[Dict[str, Any]]): 按元数据过滤。默认为None。
fetch_k: (Optional[int]) 在过滤之前要获取的文档数量。
默认为20。
**kwargs: 要传递给相似性搜索的kwargs。可以包括:
nprobe: 可选,如果使用IVF_FLAT索引,则要检查的分区数
score_threshold: 可选,一个浮点值,用于过滤
检索到的文档集的结果
返回:
查询文本最相似的文档列表,以及每个文档的浮点距离。较低的分数表示更相似。
"""
if "score_threshold" in kwargs:
score_threshold = kwargs.pop("score_threshold")
else:
score_threshold = MAX_FLOAT
d, i = self.vector_index.query(
np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
k=k if filter is None else fetch_k,
**kwargs,
)
return self.process_index_results(
ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
)
[docs] def similarity_search_with_score(
self,
query: str,
*,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
filter(可选[Dict[str,str]]):按元数据筛选。默认为无。
fetch_k:(可选[int])在过滤之前要获取的文档数量。
默认为20。
返回:
与查询文本最相似的文档列表,带有浮点距离。较低的分数表示更相似。
"""
embedding = self.embedding_function(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return docs
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
filter(可选[Dict[str, str]]):按元数据过滤。默认为None。
fetch_k: (可选[int])在过滤之前要获取的文档数量。
默认为20。
返回:
与嵌入最相似的文档列表。
"""
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return [doc for doc, _ in docs_and_scores]
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
fetch_k: int = 20,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query: 要查找相似文档的文本。
k: 要返回的文档数量。默认为4。
filter: (可选[Dict[str, str]]):按元数据筛选。默认为None。
fetch_k: (可选[int])在过滤之前要获取的文档数量。
默认为20。
返回:
与查询最相似的文档列表。
"""
docs_and_scores = self.similarity_search_with_score(
query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
)
return [doc for doc, _ in docs_and_scores]
[docs] def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""使用最大边际相关性返回选定的文档及其相似性分数。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
fetch_k:在过滤到MMR算法之前要获取的文档数量。
lambda_mult:介于0和1之间的数字,确定结果之间多样性的程度,其中0对应于最大多样性,1对应于最小多样性。默认为0.5。
返回:
通过最大边际相关性选择的文档和相似性分数的列表,以及每个文档的分数。
"""
if "score_threshold" in kwargs:
score_threshold = kwargs.pop("score_threshold")
else:
score_threshold = MAX_FLOAT
scores, indices = self.vector_index.query(
np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
k=fetch_k if filter is None else fetch_k * 2,
**kwargs,
)
results = self.process_index_results(
ids=indices[0],
scores=scores[0],
filter=filter,
k=fetch_k if filter is None else fetch_k * 2,
score_threshold=score_threshold,
)
embeddings = [
self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
]
mmr_selected = maximal_marginal_relevance(
np.array([embedding], dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
docs_and_scores = []
for i in mmr_selected:
docs_and_scores.append(results[i])
return docs_and_scores
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
fetch_k: 在过滤到传递给MMR算法之前要获取的文档数量。
lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度,其中0对应于最大多样性,1对应于最小多样性。默认为0.5。
返回:
由最大边际相关性选择的文档列表。
"""
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
return [doc for doc, _ in docs_and_scores]
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:在过滤(如果需要)传递给MMR算法之前要获取的文档数量。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,0对应最大多样性,1对应最小多样性。默认为0.5。
返回:
通过最大边际相关性选择的文档列表。
"""
embedding = self.embedding_function(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
return docs
[docs] @classmethod
def create(
cls,
index_uri: str,
index_type: str,
dimensions: int,
vector_type: np.dtype,
*,
metadatas: bool = True,
config: Optional[Mapping[str, Any]] = None,
) -> None:
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
with tiledb.scope_ctx(ctx_or_config=config):
try:
tiledb.group_create(index_uri)
except tiledb.TileDBError as err:
raise err
group = tiledb.Group(index_uri, "w")
vector_index_uri = get_vector_index_uri(group.uri)
docs_uri = get_documents_array_uri(group.uri)
if index_type == "FLAT":
tiledb_vs.flat_index.create(
uri=vector_index_uri,
dimensions=dimensions,
vector_type=vector_type,
config=config,
)
elif index_type == "IVF_FLAT":
tiledb_vs.ivf_flat_index.create(
uri=vector_index_uri,
dimensions=dimensions,
vector_type=vector_type,
config=config,
)
group.add(vector_index_uri, name=VECTOR_INDEX_NAME)
# Create TileDB array to store Documents
# TODO add a Document store API to tiledb-vector-search to allow storing
# different types of objects and metadata in a more generic way.
dim = tiledb.Dim(
name="id",
domain=(0, MAX_UINT64 - 1),
dtype=np.dtype(np.uint64),
)
dom = tiledb.Domain(dim)
text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
attrs = [text_attr]
if metadatas:
metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
attrs.append(metadata_attr)
schema = tiledb.ArraySchema(
domain=dom,
sparse=True,
allows_duplicates=False,
attrs=attrs,
)
tiledb.Array.create(docs_uri, schema)
group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
group.close()
@classmethod
def __from(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
index_uri: str,
*,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
metric: str = DEFAULT_METRIC,
index_type: str = "FLAT",
config: Optional[Mapping[str, Any]] = None,
index_timestamp: int = 0,
**kwargs: Any,
) -> TileDB:
if metric not in INDEX_METRICS:
raise ValueError(
(
f"Unsupported distance metric: {metric}. "
f"Expected one of {list(INDEX_METRICS)}"
)
)
tiledb_vs, tiledb = (
guard_import("tiledb.vector_search"),
guard_import("tiledb"),
)
input_vectors = np.array(embeddings).astype(np.float32)
cls.create(
index_uri=index_uri,
index_type=index_type,
dimensions=input_vectors.shape[1],
vector_type=input_vectors.dtype,
metadatas=metadatas is not None,
config=config,
)
with tiledb.scope_ctx(ctx_or_config=config):
if not embeddings:
raise ValueError("embeddings must be provided to build a TileDB index")
vector_index_uri = get_vector_index_uri(index_uri)
docs_uri = get_documents_array_uri(index_uri)
if ids is None:
ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
external_ids = np.array(ids).astype(np.uint64)
tiledb_vs.ingestion.ingest(
index_type=index_type,
index_uri=vector_index_uri,
input_vectors=input_vectors,
external_ids=external_ids,
index_timestamp=index_timestamp if index_timestamp != 0 else None,
config=config,
**kwargs,
)
with tiledb.open(docs_uri, "w") as A:
if external_ids is None:
external_ids = np.zeros(len(texts), dtype=np.uint64)
for i in range(len(texts)):
external_ids[i] = i
data = {}
data["text"] = np.array(texts)
if metadatas is not None:
metadata_attr = np.empty([len(metadatas)], dtype=object)
i = 0
for metadata in metadatas:
metadata_attr[i] = np.frombuffer(
pickle.dumps(metadata), dtype=np.uint8
)
i += 1
data["metadata"] = metadata_attr
A[external_ids] = data
return cls(
embedding=embedding,
index_uri=index_uri,
metric=metric,
config=config,
**kwargs,
)
[docs] def delete(
self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
) -> Optional[bool]:
"""根据向量ID或其他条件删除。
参数:
ids:要删除的ID列表。
timestamp:可选的时间戳以删除。
**kwargs:子类可能使用的其他关键字参数。
返回:
Optional[bool]:如果删除成功则为True,否则为False,如果未实现则为None。
"""
external_ids = np.array(ids).astype(np.uint64)
self.vector_index.delete_batch(
external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
)
return True
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
timestamp: int = 0,
**kwargs: Any,
) -> List[str]:
"""运行更多的文本通过嵌入并添加到向量存储。
参数:
texts:要添加到向量存储的字符串的可迭代对象。
metadatas:与文本相关的元数据的可选列表。
ids:每个文本对象的可选id。
timestamp:写入新文本的可选时间戳。
kwargs:向量存储特定参数
返回:
将文本添加到向量存储中的id列表。
"""
tiledb = guard_import("tiledb")
embeddings = self.embedding.embed_documents(list(texts))
if ids is None:
ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
external_ids = np.array(ids).astype(np.uint64)
vectors = np.empty((len(embeddings)), dtype="O")
for i in range(len(embeddings)):
vectors[i] = np.array(embeddings[i], dtype=np.float32)
self.vector_index.update_batch(
vectors=vectors,
external_ids=external_ids,
timestamp=timestamp if timestamp != 0 else None,
)
docs = {}
docs["text"] = np.array(texts)
if metadatas is not None:
metadata_attr = np.empty([len(metadatas)], dtype=object)
i = 0
for metadata in metadatas:
metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
i += 1
docs["metadata"] = metadata_attr
docs_array = tiledb.open(
self.docs_array_uri,
"w",
timestamp=timestamp if timestamp != 0 else None,
config=self.config,
)
docs_array[external_ids] = docs
docs_array.close()
return ids
[docs] @classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
metric: str = DEFAULT_METRIC,
index_uri: str = "/tmp/tiledb_array",
index_type: str = "FLAT",
config: Optional[Mapping[str, Any]] = None,
index_timestamp: int = 0,
**kwargs: Any,
) -> TileDB:
"""从原始文档构建TileDB索引。
参数:
texts:要索引的文档列表。
embedding:要使用的嵌入函数。
metadatas:要与文档关联的元数据字典列表。
ids:每个文本对象的可选ID。
metric:用于索引的度量。默认为"euclidean"。
index_uri:要写入TileDB数组的URI。
index_type:可选,向量索引类型("FLAT",IVF_FLAT")。
config:可选,TileDB配置。
index_timestamp:可选,用于写入新文本的时间戳。
示例:
.. code-block:: python
from langchain_community import TileDB
from langchain_community.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
index = TileDB.from_texts(texts, embeddings)
"""
embeddings = []
embeddings = embedding.embed_documents(texts)
return cls.__from(
texts=texts,
embeddings=embeddings,
embedding=embedding,
metadatas=metadatas,
ids=ids,
metric=metric,
index_uri=index_uri,
index_type=index_type,
config=config,
index_timestamp=index_timestamp,
**kwargs,
)
[docs] @classmethod
def from_embeddings(
cls,
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
index_uri: str,
*,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
metric: str = DEFAULT_METRIC,
index_type: str = "FLAT",
config: Optional[Mapping[str, Any]] = None,
index_timestamp: int = 0,
**kwargs: Any,
) -> TileDB:
"""从嵌入构建TileDB索引。
参数:
text_embeddings:包含(文本,嵌入)元组的列表
embedding:要使用的嵌入函数。
index_uri:要写入TileDB数组的URI
metadatas:要与文档关联的元数据字典列表。
metric:可选,用于索引的度量。默认为"euclidean"。
index_type:可选,向量索引类型("FLAT",IVF_FLAT")
config:可选,TileDB配置
index_timestamp:可选,用于写入新文本的时间戳。
示例:
.. code-block:: python
from langchain_community import TileDB
from langchain_community.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
text_embeddings = embeddings.embed_documents(texts)
text_embedding_pairs = list(zip(texts, text_embeddings))
db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
"""
texts = [t[0] for t in text_embeddings]
embeddings = [t[1] for t in text_embeddings]
return cls.__from(
texts=texts,
embeddings=embeddings,
embedding=embedding,
metadatas=metadatas,
ids=ids,
metric=metric,
index_uri=index_uri,
index_type=index_type,
config=config,
index_timestamp=index_timestamp,
**kwargs,
)
[docs] @classmethod
def load(
cls,
index_uri: str,
embedding: Embeddings,
*,
metric: str = DEFAULT_METRIC,
config: Optional[Mapping[str, Any]] = None,
timestamp: Any = None,
**kwargs: Any,
) -> TileDB:
"""从URI加载TileDB索引。
参数:
index_uri:TileDB向量索引的URI。
embedding:生成查询时要使用的嵌入。
metric:可选,用于索引的度量。默认为"euclidean"。
config:可选的TileDB配置。
timestamp:可选的时间戳,用于打开数组。
"""
return cls(
embedding=embedding,
index_uri=index_uri,
metric=metric,
config=config,
timestamp=timestamp,
**kwargs,
)
[docs] def consolidate_updates(self, **kwargs: Any) -> None:
self.vector_index = self.vector_index.consolidate_updates(**kwargs)