Source code for langchain_community.vectorstores.tiledb

"""封装了TileDB向量数据库。"""

from __future__ import annotations

import pickle
import random
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

INDEX_METRICS = frozenset(["euclidean"])
DEFAULT_METRIC = "euclidean"
DOCUMENTS_ARRAY_NAME = "documents"
VECTOR_INDEX_NAME = "vectors"
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
MAX_FLOAT = sys.float_info.max


def dependable_tiledb_import() -> Any:
    """Import tiledb-vector-search if available, otherwise raise error."""
    return (
        guard_import("tiledb.vector_search"),
        guard_import("tiledb"),
    )

def get_vector_index_uri_from_group(group: Any) -> str:
    """Get the URI of the vector index."""
    return group[VECTOR_INDEX_NAME].uri

def get_documents_array_uri_from_group(group: Any) -> str:
    """Get the URI of the documents array from a group.

    Args:
        group: TileDB group object.

    Returns:
        URI of the documents array.
    """
    return group[DOCUMENTS_ARRAY_NAME].uri

def get_vector_index_uri(uri: str) -> str:
    """Get the URI of the vector index."""
    return f"{uri}/{VECTOR_INDEX_NAME}"

def get_documents_array_uri(uri: str) -> str:
    """Get the URI of the documents array."""
    return f"{uri}/{DOCUMENTS_ARRAY_NAME}"

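These helpers encode the fixed layout of an index group: the group at ``uri`` holds a ``vectors`` member (the vector index) and a ``documents`` member (the docs array). A quick sketch with a hypothetical local URI:

.. code-block:: python

    get_vector_index_uri("/tmp/tiledb_array")     # -> "/tmp/tiledb_array/vectors"
    get_documents_array_uri("/tmp/tiledb_array")  # -> "/tmp/tiledb_array/documents"
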
class TileDB(VectorStore):
    """TileDB vector store.

    To use, you should have the ``tiledb-vector-search`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import TileDB
            embeddings = OpenAIEmbeddings()
            db = TileDB(embeddings, index_uri, metric)
    """
    def __init__(
        self,
        embedding: Embeddings,
        index_uri: str,
        metric: str,
        *,
        vector_index_uri: str = "",
        docs_array_uri: str = "",
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        allow_dangerous_deserialization: bool = False,
        **kwargs: Any,
    ):
        """Initialize the required components.

        Args:
            allow_dangerous_deserialization: whether to allow deserialization
                of the data, which involves loading it using pickle.
                The data could have been modified by malicious actors to deliver
                a payload that results in execution of arbitrary code on your
                machine.
        """
        if not allow_dangerous_deserialization:
            raise ValueError(
                "TileDB relies on pickle for serialization and deserialization. "
                "This can be dangerous if the data is intercepted and/or modified "
                "by malicious actors prior to being de-serialized. "
                "If you are sure that the data is safe from modification, you can "
                "set allow_dangerous_deserialization=True to proceed. "
                "Loading of compromised data using pickle can result in execution "
                "of arbitrary code on your machine."
            )
        self.embedding = embedding
        self.embedding_function = embedding.embed_query
        self.index_uri = index_uri
        self.metric = metric
        self.config = config

        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            index_group = tiledb.Group(self.index_uri, "r")
            self.vector_index_uri = (
                vector_index_uri
                if vector_index_uri != ""
                else get_vector_index_uri_from_group(index_group)
            )
            self.docs_array_uri = (
                docs_array_uri
                if docs_array_uri != ""
                else get_documents_array_uri_from_group(index_group)
            )
            index_group.close()
            group = tiledb.Group(self.vector_index_uri, "r")
            self.index_type = group.meta.get("index_type")
            group.close()
            self.timestamp = timestamp
            if self.index_type == "FLAT":
                self.vector_index = tiledb_vs.flat_index.FlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )
            elif self.index_type == "IVF_FLAT":
                self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
                    uri=self.vector_index_uri,
                    config=self.config,
                    timestamp=self.timestamp,
                    **kwargs,
                )

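Because the constructor refuses to proceed without the explicit pickle opt-in, opening an existing index looks like the following sketch (assuming an index was previously built at ``/tmp/tiledb_array`` and that ``embeddings`` is the same ``Embeddings`` implementation used at build time):

.. code-block:: python

    from langchain_community.vectorstores import TileDB

    db = TileDB(
        embedding=embeddings,
        index_uri="/tmp/tiledb_array",
        metric="euclidean",
        allow_dangerous_deserialization=True,  # opt in to pickle-backed metadata
    )
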
    @property
    def embeddings(self) -> Optional[Embeddings]:
        return self.embedding

    def process_index_results(
        self,
        ids: List[int],
        scores: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        score_threshold: float = MAX_FLOAT,
    ) -> List[Tuple[Document, float]]:
        """Turn TileDB results into a list of documents and scores.

        Args:
            ids: List of ids of the documents in the index.
            scores: List of distances of the documents in the index.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            score_threshold: Optional, a floating point value to filter the
                resulting set of retrieved docs.

        Returns:
            List of Documents and scores.
        """
        tiledb = guard_import("tiledb")
        docs = []
        docs_array = tiledb.open(
            self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
        )
        for idx, score in zip(ids, scores):
            if idx == 0 and score == 0:
                continue
            if idx == MAX_UINT64 and score == MAX_FLOAT_32:
                continue
            doc = docs_array[idx]
            if doc is None or len(doc["text"]) == 0:
                raise ValueError(f"Could not find document for id {idx}, got {doc}")
            pickled_metadata = doc.get("metadata")
            result_doc = Document(page_content=str(doc["text"][0]))
            if pickled_metadata is not None:
                metadata = pickle.loads(
                    np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
                )
                result_doc.metadata = metadata
            if filter is not None:
                # Normalize scalar filter values into single-element lists.
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(
                    result_doc.metadata.get(key) in value
                    for key, value in filter.items()
                ):
                    docs.append((result_doc, score))
            else:
                docs.append((result_doc, score))
        docs_array.close()
        docs = [(doc, score) for doc, score in docs if score <= score_threshold]
        return docs[:k]

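Note the filter semantics implemented above: each filter value may be a scalar or a list, scalars are wrapped into single-element lists, and a list matches if the document's metadata value equals any member. A sketch (``source`` is a hypothetical metadata key):

.. code-block:: python

    filter = {"source": "news"}            # equivalent to {"source": ["news"]}
    filter = {"source": ["news", "blog"]}  # matches either value
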
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            embedding: Embedding vector to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.
            **kwargs: kwargs to be passed to similarity search. Can include:
                nprobe: Optional, number of partitions to check if using
                    IVF_FLAT index.
                score_threshold: Optional, a floating point value to filter the
                    resulting set of retrieved docs.

        Returns:
            List of documents most similar to the query text and distance
            in float for each. Lower score represents more similarity.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
        d, i = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=k if filter is None else fetch_k,
            **kwargs,
        )
        return self.process_index_results(
            ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
        )

    def similarity_search_with_score(
        self,
        query: str,
        *,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of documents most similar to the query text with
            distance in float. Lower score represents more similarity.
        """
        embedding = self.embedding_function(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return docs

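A minimal query sketch, assuming ``db`` is an open ``TileDB`` store as above; ``score_threshold`` is popped from kwargs by ``similarity_search_with_score_by_vector``, and ``nprobe`` applies only to IVF_FLAT indexes:

.. code-block:: python

    docs_and_scores = db.similarity_search_with_score(
        "What is a vector database?",  # hypothetical query
        k=4,
        filter={"source": "news"},     # hypothetical metadata filter
        score_threshold=1.5,           # drop hits farther than this distance
    )
    for doc, score in docs_and_scores:
        print(score, doc.page_content)
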
    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
            fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
                Defaults to 20.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding,
            k=k,
            filter=filter,
            fetch_k=fetch_k,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores selected using
        maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance and
            the similarity score for each.
        """
        if "score_threshold" in kwargs:
            score_threshold = kwargs.pop("score_threshold")
        else:
            score_threshold = MAX_FLOAT
        scores, indices = self.vector_index.query(
            np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
            k=fetch_k if filter is None else fetch_k * 2,
            **kwargs,
        )
        results = self.process_index_results(
            ids=indices[0],
            scores=scores[0],
            filter=filter,
            k=fetch_k if filter is None else fetch_k * 2,
            score_threshold=score_threshold,
        )
        embeddings = [
            self.embedding.embed_documents([doc.page_content])[0]
            for doc, _ in results
        ]
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        return [results[i] for i in mmr_selected]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )
        return [doc for doc, _ in docs_and_scores]

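An MMR query sketch over a precomputed vector, assuming ``db`` and ``embeddings`` as above. Note the extra embedding cost: each candidate's text is re-embedded before MMR selection runs:

.. code-block:: python

    query_vector = embeddings.embed_query("renewable energy")  # hypothetical query
    docs = db.max_marginal_relevance_search_by_vector(
        query_vector,
        k=4,
        fetch_k=20,
        lambda_mult=0.5,  # 0 = maximum diversity, 1 = pure relevance
    )
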
    @classmethod
    def create(
        cls,
        index_uri: str,
        index_type: str,
        dimensions: int,
        vector_type: np.dtype,
        *,
        metadatas: bool = True,
        config: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Create an empty TileDB index group at index_uri, holding a vector
        index and a documents array."""
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            tiledb.group_create(index_uri)
            group = tiledb.Group(index_uri, "w")
            vector_index_uri = get_vector_index_uri(group.uri)
            docs_uri = get_documents_array_uri(group.uri)
            if index_type == "FLAT":
                tiledb_vs.flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            elif index_type == "IVF_FLAT":
                tiledb_vs.ivf_flat_index.create(
                    uri=vector_index_uri,
                    dimensions=dimensions,
                    vector_type=vector_type,
                    config=config,
                )
            group.add(vector_index_uri, name=VECTOR_INDEX_NAME)

            # Create a TileDB array to store Documents.
            # TODO add a Document store API to tiledb-vector-search to allow
            #  storing different types of objects and metadata in a more
            #  generic way.
            dim = tiledb.Dim(
                name="id",
                domain=(0, MAX_UINT64 - 1),
                dtype=np.dtype(np.uint64),
            )
            dom = tiledb.Domain(dim)

            text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
            attrs = [text_attr]
            if metadatas:
                metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
                attrs.append(metadata_attr)
            schema = tiledb.ArraySchema(
                domain=dom,
                sparse=True,
                allows_duplicates=False,
                attrs=attrs,
            )
            tiledb.Array.create(docs_uri, schema)
            group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
            group.close()

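A sketch that pre-creates an empty FLAT index; the dimensionality here is hypothetical and must match the embedding model that will be used:

.. code-block:: python

    import numpy as np

    TileDB.create(
        index_uri="/tmp/tiledb_array",
        index_type="FLAT",
        dimensions=1536,  # hypothetical embedding size
        vector_type=np.dtype("float32"),
    )
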
    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        if metric not in INDEX_METRICS:
            raise ValueError(
                f"Unsupported distance metric: {metric}. "
                f"Expected one of {list(INDEX_METRICS)}"
            )
        # Check before touching input_vectors.shape, which would fail on
        # an empty array with a less helpful error.
        if not embeddings:
            raise ValueError("embeddings must be provided to build a TileDB index")
        tiledb_vs, tiledb = (
            guard_import("tiledb.vector_search"),
            guard_import("tiledb"),
        )
        input_vectors = np.array(embeddings).astype(np.float32)
        cls.create(
            index_uri=index_uri,
            index_type=index_type,
            dimensions=input_vectors.shape[1],
            vector_type=input_vectors.dtype,
            metadatas=metadatas is not None,
            config=config,
        )
        with tiledb.scope_ctx(ctx_or_config=config):
            vector_index_uri = get_vector_index_uri(index_uri)
            docs_uri = get_documents_array_uri(index_uri)
            if ids is None:
                ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
            external_ids = np.array(ids).astype(np.uint64)
            tiledb_vs.ingestion.ingest(
                index_type=index_type,
                index_uri=vector_index_uri,
                input_vectors=input_vectors,
                external_ids=external_ids,
                index_timestamp=index_timestamp if index_timestamp != 0 else None,
                config=config,
                **kwargs,
            )
            with tiledb.open(docs_uri, "w") as A:
                data = {"text": np.array(texts)}
                if metadatas is not None:
                    metadata_attr = np.empty([len(metadatas)], dtype=object)
                    for i, metadata in enumerate(metadatas):
                        metadata_attr[i] = np.frombuffer(
                            pickle.dumps(metadata), dtype=np.uint8
                        )
                    data["metadata"] = metadata_attr
                A[external_ids] = data
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            **kwargs,
        )

    def delete(
        self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
    ) -> Optional[bool]:
        """Delete by vector ID or other criteria.

        Args:
            ids: List of ids to delete.
            timestamp: Optional timestamp to delete with.
            **kwargs: Other keyword arguments that subclasses might use.

        Returns:
            Optional[bool]: True if deletion is successful,
                False otherwise, None if not implemented.
        """
        external_ids = np.array(ids).astype(np.uint64)
        self.vector_index.delete_batch(
            external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
        )
        return True

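Deletion is keyed on the external ids returned at ingestion time, for example:

.. code-block:: python

    ids = db.add_texts(["temporary note"])  # hypothetical text
    db.delete(ids=ids)
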
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        timestamp: int = 0,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional ids of each text object.
            timestamp: Optional timestamp to write new texts with.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        tiledb = guard_import("tiledb")
        # Materialize once: a generator would otherwise be exhausted by
        # embed_documents before being reused below.
        texts = list(texts)
        embeddings = self.embedding.embed_documents(texts)
        if ids is None:
            ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]

        external_ids = np.array(ids).astype(np.uint64)
        vectors = np.empty((len(embeddings)), dtype="O")
        for i in range(len(embeddings)):
            vectors[i] = np.array(embeddings[i], dtype=np.float32)
        self.vector_index.update_batch(
            vectors=vectors,
            external_ids=external_ids,
            timestamp=timestamp if timestamp != 0 else None,
        )

        docs = {"text": np.array(texts)}
        if metadatas is not None:
            metadata_attr = np.empty([len(metadatas)], dtype=object)
            for i, metadata in enumerate(metadatas):
                metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
            docs["metadata"] = metadata_attr

        docs_array = tiledb.open(
            self.docs_array_uri,
            "w",
            timestamp=timestamp if timestamp != 0 else None,
            config=self.config,
        )
        docs_array[external_ids] = docs
        docs_array.close()
        return ids

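An incremental-write sketch, assuming ``db`` as above (texts and metadata keys are hypothetical):

.. code-block:: python

    new_ids = db.add_texts(
        texts=["TileDB arrays are versioned.", "Vectors can live on object storage."],
        metadatas=[{"source": "docs"}, {"source": "blog"}],
    )
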
    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_uri: str = "/tmp/tiledb_array",
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct a TileDB index from raw documents.

        Args:
            texts: List of documents to index.
            embedding: Embedding function to use.
            metadatas: List of metadata dictionaries to associate with documents.
            ids: Optional ids of each text object.
            metric: Metric to use for indexing. Defaults to "euclidean".
            index_uri: The URI to write the TileDB arrays to.
            index_type: Optional, vector index type ("FLAT", "IVF_FLAT").
            config: Optional, TileDB config.
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community.vectorstores import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                index = TileDB.from_texts(texts, embeddings)
        """
        embeddings = embedding.embed_documents(texts)
        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def from_embeddings(
        cls,
        text_embeddings: List[Tuple[str, List[float]]],
        embedding: Embeddings,
        index_uri: str,
        *,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        metric: str = DEFAULT_METRIC,
        index_type: str = "FLAT",
        config: Optional[Mapping[str, Any]] = None,
        index_timestamp: int = 0,
        **kwargs: Any,
    ) -> TileDB:
        """Construct a TileDB index from embeddings.

        Args:
            text_embeddings: List of tuples of (text, embedding).
            embedding: Embedding function to use.
            index_uri: The URI to write the TileDB arrays to.
            metadatas: List of metadata dictionaries to associate with documents.
            metric: Optional, metric to use for indexing. Defaults to "euclidean".
            index_type: Optional, vector index type ("FLAT", "IVF_FLAT").
            config: Optional, TileDB config.
            index_timestamp: Optional, timestamp to write new texts with.

        Example:
            .. code-block:: python

                from langchain_community.vectorstores import TileDB
                from langchain_community.embeddings import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                text_embeddings = embeddings.embed_documents(texts)
                text_embedding_pairs = list(zip(texts, text_embeddings))
                db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
        """
        texts = [t[0] for t in text_embeddings]
        embeddings = [t[1] for t in text_embeddings]
        return cls.__from(
            texts=texts,
            embeddings=embeddings,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            metric=metric,
            index_uri=index_uri,
            index_type=index_type,
            config=config,
            index_timestamp=index_timestamp,
            **kwargs,
        )

    @classmethod
    def load(
        cls,
        index_uri: str,
        embedding: Embeddings,
        *,
        metric: str = DEFAULT_METRIC,
        config: Optional[Mapping[str, Any]] = None,
        timestamp: Any = None,
        **kwargs: Any,
    ) -> TileDB:
        """Load a TileDB index from a URI.

        Args:
            index_uri: The URI of the TileDB vector index.
            embedding: Embeddings to use when generating queries.
            metric: Optional, metric to use for indexing. Defaults to "euclidean".
            config: Optional, TileDB config.
            timestamp: Optional, timestamp to use for opening the arrays.
        """
        return cls(
            embedding=embedding,
            index_uri=index_uri,
            metric=metric,
            config=config,
            timestamp=timestamp,
            **kwargs,
        )

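Reopening a persisted index later, for example (``allow_dangerous_deserialization`` is forwarded through ``**kwargs`` to the constructor):

.. code-block:: python

    db = TileDB.load(
        index_uri="/tmp/tiledb_array",
        embedding=embeddings,
        allow_dangerous_deserialization=True,
    )
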
    def consolidate_updates(self, **kwargs: Any) -> None:
        """Consolidate the batched updates into the main vector index."""
        self.vector_index = self.vector_index.consolidate_updates(**kwargs)
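
Updates written by ``add_texts``/``delete`` accumulate alongside the main index; a sketch of folding them in after a batch of writes (kwargs are forwarded to tiledb-vector-search):

.. code-block:: python

    db.add_texts(["late-arriving document"])
    db.consolidate_updates()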