Source code for langchain_community.vectorstores.astradb

from __future__ import annotations

import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore

from langchain_community.utilities.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    from astrapy.db import AstraDB as LibAstraDB
    from astrapy.db import AsyncAstraDB

ADBVST = TypeVar("ADBVST", bound="AstraDB")
T = TypeVar("T")
U = TypeVar("U")
DocDict = Dict[str, Any]  # dicts expressing entries to insert

# Batch/concurrency default values (if parameters not provided):
# Size of batches for bulk insertions:
#   (20 is the max batch size for the HTTP API at the time of writing)
DEFAULT_BATCH_SIZE = 20
# Number of threads to insert batches concurrently:
DEFAULT_BULK_INSERT_BATCH_CONCURRENCY = 16
# Number of threads in a batch to insert pre-existing entries:
DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY = 10
# Number of threads (for deleting multiple rows concurrently):
DEFAULT_BULK_DELETE_CONCURRENCY = 20


def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
    visited_keys: Set[U] = set()
    new_lst = []
    for item in lst:
        item_key = key(item)
        if item_key not in visited_keys:
            visited_keys.add(item_key)
            new_lst.append(item)
    return new_lst


[docs]@deprecated( since="0.0.21", removal="0.3.0", alternative_import="langchain_astradb.AstraDBVectorStore", ) class AstraDB(VectorStore): @staticmethod def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]: if filter_dict is None: return {} else: metadata_filter = {} for k, v in filter_dict.items(): if k and k[0] == "$": if isinstance(v, list): metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v] else: metadata_filter[k] = AstraDB._filter_to_metadata(v) # type: ignore[assignment] else: metadata_filter[f"metadata.{k}"] = v return metadata_filter
[docs] def __init__( self, *, embedding: Embeddings, collection_name: str, token: Optional[str] = None, api_endpoint: Optional[str] = None, astra_db_client: Optional[LibAstraDB] = None, async_astra_db_client: Optional[AsyncAstraDB] = None, namespace: Optional[str] = None, metric: Optional[str] = None, batch_size: Optional[int] = None, bulk_insert_batch_concurrency: Optional[int] = None, bulk_insert_overwrite_concurrency: Optional[int] = None, bulk_delete_concurrency: Optional[int] = None, setup_mode: SetupMode = SetupMode.SYNC, pre_delete_collection: bool = False, ) -> None: """封装了用于向量存储工作负载的 DataStax Astra DB。 要快速开始并了解详情,请访问 https://docs.datastax.com/en/astra/astra-db-vector/ 示例: .. code-block:: python from langchain_community.vectorstores import AstraDB from langchain_openai.embeddings import OpenAIEmbeddings embeddings = OpenAIEmbeddings() vectorstore = AstraDB( embedding=embeddings, collection_name="my_store", token="AstraCS:...", api_endpoint="https://<DB-ID>-<REGION>.apps.astra.datastax.com" ) vectorstore.add_texts(["Giraffes", "All good here"]) results = vectorstore.similarity_search("Everything's ok", k=1) 参数: embedding: 要使用的嵌入函数。 collection_name: 要创建/使用的 Astra DB 集合的名称。 token: 用于 Astra DB 使用的 API 令牌。 api_endpoint: API 端点的完整 URL,例如 `https://<DB-ID>-us-east1.apps.astra.datastax.com`。 astra_db_client: *token+api_endpoint 的替代方案*, 您可以传递一个已创建的 'astrapy.db.AstraDB' 实例。 async_astra_db_client: *token+api_endpoint 的替代方案*, 您可以传递一个已创建的 'astrapy.db.AsyncAstraDB' 实例。 namespace: 创建集合的命名空间(又名 keyspace)。 默认为数据库的 "default namespace"。 metric: 要在 Astra DB 中使用的相似性函数之一。 如果省略,将使用 Astra DB API 的默认值(即 "cosine" - 但是, 出于性能原因,如果嵌入已归一化为一,建议使用 "dot_product")。 batch_size: 用于批量插入的批量大小。 bulk_insert_batch_concurrency: 并发插入批次的线程数或协程数。 bulk_insert_overwrite_concurrency: 在批次中插入已存在条目的线程数或协程数。 bulk_delete_concurrency: 并发删除多行的线程数。 pre_delete_collection: 是否在创建集合之前删除集合。 如果为 False 并且集合已存在,则将使用该集合。 注意: 对于同步 :meth:`~add_texts` 中的并发性,根据经验,在典型客户端机器上, 建议将数量 bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency 保持远低于 1000,以避免耗尽客户端的多线程/网络资源。 预设的默认值相对保守,以满足大多数机器的规格,但一个明智的选择可能是: - bulk_insert_batch_concurrency = 80 - bulk_insert_overwrite_concurrency = 10 需要进行一些实验来找到最佳结果, 这取决于机器/网络规格以及预期工作负载(特别是写入是现有 id 的更新的频率)。 请记住,您也可以将并发设置传递给 :meth:`~add_texts` 和 :meth:`~add_documents` 的各个调用。 """ self.embedding = embedding self.collection_name = collection_name self.token = token self.api_endpoint = api_endpoint self.namespace = namespace # Concurrency settings self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE self.bulk_insert_batch_concurrency: int = ( bulk_insert_batch_concurrency or DEFAULT_BULK_INSERT_BATCH_CONCURRENCY ) self.bulk_insert_overwrite_concurrency: int = ( bulk_insert_overwrite_concurrency or DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY ) self.bulk_delete_concurrency: int = ( bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY ) # "vector-related" settings self.metric = metric embedding_dimension: Union[int, Awaitable[int], None] = None if setup_mode == SetupMode.ASYNC: embedding_dimension = self._aget_embedding_dimension() elif setup_mode == SetupMode.SYNC: embedding_dimension = self._get_embedding_dimension() self.astra_env = _AstraDBCollectionEnvironment( collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, namespace=namespace, setup_mode=setup_mode, pre_delete_collection=pre_delete_collection, embedding_dimension=embedding_dimension, metric=metric, ) self.astra_db = self.astra_env.astra_db self.async_astra_db = self.astra_env.async_astra_db self.collection = self.astra_env.collection self.async_collection = self.astra_env.async_collection
def _get_embedding_dimension(self) -> int: return len(self.embedding.embed_query(text="This is a sample sentence.")) async def _aget_embedding_dimension(self) -> int: return len(await self.embedding.aembed_query(text="This is a sample sentence.")) @property def embeddings(self) -> Embeddings: return self.embedding @staticmethod def _dont_flip_the_cos_score(similarity0to1: float) -> float: """保持客户端的相似度不变,因为它已经在[0:1]范围内。""" return similarity0to1 def _select_relevance_score_fn(self) -> Callable[[float], float]: """基础的API调用已经返回了一个“适当的分数”,即在[0, 1]之间,数值越高表示*更相似*,因此这里的最终分数转换不会颠倒区间: """ return self._dont_flip_the_cos_score
[docs] def clear(self) -> None: """清空集合中存储的所有条目。""" self.astra_env.ensure_db_setup() self.collection.delete_many({})
[docs] async def aclear(self) -> None: """清空集合中存储的所有条目。""" await self.astra_env.aensure_db_setup() await self.async_collection.delete_many({}) # type: ignore[union-attr]
[docs] def delete_by_document_id(self, document_id: str) -> bool: """从存储中删除单个文档,给定其文档ID。 参数: document_id:文档ID 返回值: 如果确实已删除文档,则为True,如果未找到ID,则为False。 """ self.astra_env.ensure_db_setup() deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr] return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 ) == 1
[docs] async def adelete_by_document_id(self, document_id: str) -> bool: """从存储中删除单个文档,给定其文档ID。 参数: document_id:文档ID 返回值: 如果确实已删除文档,则为True,如果未找到ID,则为False。 """ await self.astra_env.aensure_db_setup() deletion_response = await self.async_collection.delete_one(document_id) return ((deletion_response or {}).get("status") or {}).get( "deletedCount", 0 ) == 1
[docs] def delete( self, ids: Optional[List[str]] = None, concurrency: Optional[int] = None, **kwargs: Any, ) -> Optional[bool]: """根据向量ID删除。 参数: ids:要删除的ID列表。 并发性:发出单个文档删除请求的线程的最大数量。 默认为实例级设置。 返回: 如果删除成功则为True,否则为False。 """ if kwargs: warnings.warn( "Method 'delete' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) if ids is None: raise ValueError("No ids provided to delete.") _max_workers = concurrency or self.bulk_delete_concurrency with ThreadPoolExecutor(max_workers=_max_workers) as tpe: _ = list( tpe.map( self.delete_by_document_id, ids, ) ) return True
[docs] async def adelete( self, ids: Optional[List[str]] = None, concurrency: Optional[int] = None, **kwargs: Any, ) -> Optional[bool]: """根据向量ID删除。 参数: ids:要删除的ID列表。 concurrency:单个文档删除请求的最大并发数。 默认为实例级设置。 **kwargs:子类可能使用的其他关键字参数。 返回: 如果删除成功则为True,否则为False。 """ if kwargs: warnings.warn( "Method 'adelete' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) if ids is None: raise ValueError("No ids provided to delete.") return all( await gather_with_concurrency( concurrency, *[self.adelete_by_document_id(doc_id) for doc_id in ids] ) )
[docs] def delete_collection(self) -> None: """完全从数据库中删除集合(与:meth:`~clear`相反,后者仅清空集合)。 存储的数据将丢失且无法恢复,资源将被释放。 请谨慎使用。 """ self.astra_env.ensure_db_setup() self.astra_db.delete_collection( collection_name=self.collection_name, )
[docs] async def adelete_collection(self) -> None: """完全从数据库中删除集合(与:meth:`~aclear`相反,后者仅清空集合)。 存储的数据将丢失且无法恢复,资源将被释放。 请谨慎使用。 """ await self.astra_env.aensure_db_setup() await self.async_astra_db.delete_collection( collection_name=self.collection_name, )
@staticmethod def _get_documents_to_insert( texts: Iterable[str], embedding_vectors: List[List[float]], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, ) -> List[DocDict]: if ids is None: ids = [uuid.uuid4().hex for _ in texts] if metadatas is None: metadatas = [{} for _ in texts] # documents_to_insert = [ { "content": b_txt, "_id": b_id, "$vector": b_emb, "metadata": b_md, } for b_txt, b_emb, b_id, b_md in zip( texts, embedding_vectors, ids, metadatas, ) ] # make unique by id, keeping the last uniqued_documents_to_insert = _unique_list( documents_to_insert[::-1], lambda document: document["_id"], )[::-1] return uniqued_documents_to_insert @staticmethod def _get_missing_from_batch( document_batch: List[DocDict], insert_result: Dict[str, Any] ) -> Tuple[List[str], List[DocDict]]: if "status" not in insert_result: raise ValueError( f"API Exception while running bulk insertion: {str(insert_result)}" ) batch_inserted = insert_result["status"]["insertedIds"] # estimation of the preexisting documents that failed missed_inserted_ids = {document["_id"] for document in document_batch} - set( batch_inserted ) errors = insert_result.get("errors", []) # careful for other sources of error other than "doc already exists" num_errors = len(errors) unexpected_errors = any( error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors ) if num_errors != len(missed_inserted_ids) or unexpected_errors: raise ValueError( f"API Exception while running bulk insertion: {str(errors)}" ) # deal with the missing insertions as upserts missing_from_batch = [ document for document in document_batch if document["_id"] in missed_inserted_ids ] return batch_inserted, missing_from_batch
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, *, batch_size: Optional[int] = None, batch_concurrency: Optional[int] = None, overwrite_concurrency: Optional[int] = None, **kwargs: Any, ) -> List[str]: """通过嵌入将文本传递并将其添加到向量存储中。 如果传递了显式的ids,那些id已经在存储中的条目将被替换。 参数: texts:要添加到向量存储中的文本。 metadatas:元数据的可选列表。 ids:id的可选列表。 batch_size:每个API调用中的文档数量。 检查底层Astra DB HTTP API规范以获取最大值 (在编写本文时为20)。如果未提供,默认为 实例级别设置。 batch_concurrency:处理插入批次的线程数 并发。如果未提供,默认为实例级别 设置。 overwrite_concurrency:处理线程数 每个批次中的现有文档(需要单独的API调用)。 如果未提供,默认为实例级别设置。 注意: 元数据字典中允许的字段名称有限制 来自底层Astra DB API。例如,`$`(美元符号)不能在字典键中使用。 有关详细信息,请参阅此文档: https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html 返回: 添加的文本的id列表。 """ if kwargs: warnings.warn( "Method 'add_texts' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) self.astra_env.ensure_db_setup() embedding_vectors = self.embedding.embed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( texts, embedding_vectors, metadatas, ids ) def _handle_batch(document_batch: List[DocDict]) -> List[str]: im_result = self.collection.insert_many( documents=document_batch, options={"ordered": False}, partial_failures_allowed=True, ) batch_inserted, missing_from_batch = self._get_missing_from_batch( document_batch, im_result ) def _handle_missing_document(missing_document: DocDict) -> str: replacement_result = self.collection.find_one_and_replace( filter={"_id": missing_document["_id"]}, replacement=missing_document, ) return replacement_result["data"]["document"]["_id"] _u_max_workers = ( overwrite_concurrency or self.bulk_insert_overwrite_concurrency ) with ThreadPoolExecutor(max_workers=_u_max_workers) as tpe2: batch_replaced = list( tpe2.map( _handle_missing_document, missing_from_batch, ) ) return batch_inserted + batch_replaced _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe: all_ids_nested = tpe.map( _handle_batch, batch_iterate( batch_size or self.batch_size, documents_to_insert, ), ) return [iid for id_list in all_ids_nested for iid in id_list]
[docs] async def aadd_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, *, batch_size: Optional[int] = None, batch_concurrency: Optional[int] = None, overwrite_concurrency: Optional[int] = None, **kwargs: Any, ) -> List[str]: """通过嵌入将文本传递并将其添加到向量存储中。 如果传递了显式的ids,那些id已经在存储中的条目将被替换。 参数: texts:要添加到向量存储中的文本。 metadatas:元数据的可选列表。 ids:id的可选列表。 batch_size:每个API调用中的文档数量。 检查底层Astra DB HTTP API规范以获取最大值 (在编写本文时为20)。如果未提供,默认为 实例级别设置。 batch_concurrency:处理插入批次的线程数 并发。如果未提供,默认为实例级别 设置。 overwrite_concurrency:处理线程数 每个批次中的现有文档(需要单独的API调用)。 如果未提供,默认为实例级别设置。 注意: 元数据字典中允许的字段名称有限制 来自底层Astra DB API。例如,`$`(美元符号)不能在字典键中使用。 有关详细信息,请参阅此文档: https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html 返回: 添加的文本的id列表。 """ if kwargs: warnings.warn( "Method 'aadd_texts' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), " "which will be ignored." ) await self.astra_env.aensure_db_setup() embedding_vectors = await self.embedding.aembed_documents(list(texts)) documents_to_insert = self._get_documents_to_insert( texts, embedding_vectors, metadatas, ids ) async def _handle_batch(document_batch: List[DocDict]) -> List[str]: im_result = await self.async_collection.insert_many( documents=document_batch, options={"ordered": False}, partial_failures_allowed=True, ) batch_inserted, missing_from_batch = self._get_missing_from_batch( document_batch, im_result ) async def _handle_missing_document(missing_document: DocDict) -> str: replacement_result = await self.async_collection.find_one_and_replace( filter={"_id": missing_document["_id"]}, replacement=missing_document, ) return replacement_result["data"]["document"]["_id"] _u_max_workers = ( overwrite_concurrency or self.bulk_insert_overwrite_concurrency ) batch_replaced = await gather_with_concurrency( _u_max_workers, *[_handle_missing_document(doc) for doc in missing_from_batch], ) return batch_inserted + batch_replaced _b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency all_ids_nested = await gather_with_concurrency( _b_max_workers, *[ _handle_batch(batch) for batch in batch_iterate( batch_size or self.batch_size, documents_to_insert, ) ], ) return [iid for id_list in all_ids_nested for iid in id_list]
[docs] def similarity_search_with_score_id_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: """返回与嵌入向量最相似的文档,包括分数和ID。 参数: embedding:要查找相似文档的嵌入。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表,包括(文档,分数,ID)。 """ self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # hits = list( self.collection.paginated_find( filter=metadata_parameter, sort={"$vector": embedding}, options={"limit": k, "includeSimilarity": True}, projection={ "_id": 1, "content": 1, "metadata": 1, }, ) ) # return [ ( Document( page_content=hit["content"], metadata=hit["metadata"], ), hit["$similarity"], hit["_id"], ) for hit in hits ]
[docs] async def asimilarity_search_with_score_id_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: """返回与嵌入向量最相似的文档,包括分数和ID。 参数: embedding:要查找相似文档的嵌入。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表,包括(文档,分数,ID)。 """ await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) # return [ ( Document( page_content=hit["content"], metadata=hit["metadata"], ), hit["$similarity"], hit["_id"], ) async for hit in self.async_collection.paginated_find( filter=metadata_parameter, sort={"$vector": embedding}, options={"limit": k, "includeSimilarity": True}, projection={ "_id": 1, "content": 1, "metadata": 1, }, ) ]
[docs] def similarity_search_with_score_id( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: """返回与查询最相似的文档,包括分数和ID。 参数: query:要查找类似文档的查询。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询最相似的文档的(文档,分数,ID)列表。 """ embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_id_by_vector( embedding=embedding_vector, k=k, filter=filter, )
[docs] async def asimilarity_search_with_score_id( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float, str]]: """返回与查询最相似的文档,包括分数和ID。 参数: query:要查找类似文档的查询。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询最相似的文档的(文档,分数,ID)列表。 """ embedding_vector = await self.embedding.aembed_query(query) return await self.asimilarity_search_with_score_id_by_vector( embedding=embedding_vector, k=k, filter=filter, )
[docs] def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: """返回与嵌入向量最相似的文档及其分数。 参数: embedding: 要查找相似文档的嵌入。 k: 要返回的文档数量。默认为4。 filter: 要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表及其分数。 """ return [ (doc, score) for (doc, score, doc_id) in self.similarity_search_with_score_id_by_vector( embedding=embedding, k=k, filter=filter, ) ]
[docs] async def asimilarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: """返回与嵌入向量最相似的文档及其分数。 参数: embedding: 要查找相似文档的嵌入。 k: 要返回的文档数量。默认为4。 filter: 要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表及其分数。 """ return [ (doc, score) for ( doc, score, doc_id, ) in await self.asimilarity_search_with_score_id_by_vector( embedding=embedding, k=k, filter=filter, ) ]
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """返回与嵌入向量最相似的文档。 参数: embedding:要查找与之相似文档的嵌入。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表。 """ return [ doc for doc, _ in self.similarity_search_with_score_by_vector( embedding, k, filter=filter, ) ]
[docs] async def asimilarity_search_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """返回与嵌入向量最相似的文档。 参数: embedding:要查找与之相似文档的嵌入。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档列表。 """ return [ doc for doc, _ in await self.asimilarity_search_with_score_by_vector( embedding, k, filter=filter, ) ]
[docs] def similarity_search_with_score( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: """返回与查询最相似的文档及其分数。 参数: query:要查找相似文档的查询。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档及其分数的列表。 """ embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_by_vector( embedding_vector, k, filter=filter, )
[docs] async def asimilarity_search_with_score( self, query: str, k: int = 4, filter: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Document, float]]: """返回与查询最相似的文档及其分数。 参数: query:要查找相似文档的查询。 k:要返回的文档数量。默认为4。 filter:要应用的元数据过滤器。 返回: 与查询向量最相似的文档及其分数的列表。 """ embedding_vector = await self.embedding.aembed_query(query) return await self.asimilarity_search_with_score_by_vector( embedding_vector, k, filter=filter, )
@staticmethod def _get_mmr_hits( embedding: List[float], k: int, lambda_mult: float, prefetch_hits: List[DocDict] ) -> List[Document]: mmr_chosen_indices = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), [prefetch_hit["$vector"] for prefetch_hit in prefetch_hits], k=k, lambda_mult=lambda_mult, ) mmr_hits = [ prefetch_hit for prefetch_index, prefetch_hit in enumerate(prefetch_hits) if prefetch_index in mmr_chosen_indices ] return [ Document( page_content=hit["content"], metadata=hit["metadata"], ) for hit in mmr_hits ]
[docs] def max_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """返回使用最大边际相关性选择的文档。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: embedding: 查找与之相似的文档的嵌入。 k: 要返回的文档数量。 fetch_k: 要获取的文档数量,以传递给MMR算法。 lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度, 其中0对应于最大多样性,1对应于最小多样性。 filter: 要应用的元数据过滤器。 返回: 通过最大边际相关性选择的文档列表。 """ self.astra_env.ensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = list( self.collection.paginated_find( filter=metadata_parameter, sort={"$vector": embedding}, options={"limit": fetch_k, "includeSimilarity": True}, projection={ "_id": 1, "content": 1, "metadata": 1, "$vector": 1, }, ) ) return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
[docs] async def amax_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """返回使用最大边际相关性选择的文档。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: embedding: 查找与之相似的文档的嵌入。 k: 要返回的文档数量。 fetch_k: 要获取的文档数量,以传递给MMR算法。 lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度, 其中0对应于最大多样性,1对应于最小多样性。 filter: 要应用的元数据过滤器。 返回: 通过最大边际相关性选择的文档列表。 """ await self.astra_env.aensure_db_setup() metadata_parameter = self._filter_to_metadata(filter) prefetch_hits = [ hit async for hit in self.async_collection.paginated_find( filter=metadata_parameter, sort={"$vector": embedding}, options={"limit": fetch_k, "includeSimilarity": True}, projection={ "_id": 1, "content": 1, "metadata": 1, "$vector": 1, }, ) ] return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
@classmethod def _from_kwargs( cls: Type[ADBVST], embedding: Embeddings, **kwargs: Any, ) -> ADBVST: known_kwargs = { "collection_name", "token", "api_endpoint", "astra_db_client", "async_astra_db_client", "namespace", "metric", "batch_size", "bulk_insert_batch_concurrency", "bulk_insert_overwrite_concurrency", "bulk_delete_concurrency", "batch_concurrency", "overwrite_concurrency", } if kwargs: unknown_kwargs = set(kwargs.keys()) - known_kwargs if unknown_kwargs: warnings.warn( "Method 'from_texts' of AstraDB vector store invoked with " f"unsupported arguments ({', '.join(sorted(unknown_kwargs))}), " "which will be ignored." ) collection_name: str = kwargs["collection_name"] token = kwargs.get("token") api_endpoint = kwargs.get("api_endpoint") astra_db_client = kwargs.get("astra_db_client") async_astra_db_client = kwargs.get("async_astra_db_client") namespace = kwargs.get("namespace") metric = kwargs.get("metric") return cls( embedding=embedding, collection_name=collection_name, token=token, api_endpoint=api_endpoint, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, namespace=namespace, metric=metric, batch_size=kwargs.get("batch_size"), bulk_insert_batch_concurrency=kwargs.get("bulk_insert_batch_concurrency"), bulk_insert_overwrite_concurrency=kwargs.get( "bulk_insert_overwrite_concurrency" ), bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"), )
[docs] @classmethod def from_texts( cls: Type[ADBVST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, **kwargs: Any, ) -> ADBVST: """从原始文本创建一个Astra DB向量存储。 参数: texts: 要插入的文本。 embedding: 存储中要使用的嵌入函数。 metadatas: 文本的元数据字典。 ids: 要与文本关联的ID。 **kwargs: 您可以传递任何您想要的参数 给 :meth:`~add_texts` 和/或 'AstraDB' 构造函数 (有关详细信息,请参阅这些方法)。这些参数将被 传递到相应的方法中。 返回: 一个 `AstraDb` 向量存储。 """ astra_db_store = AstraDB._from_kwargs(embedding, **kwargs) astra_db_store.add_texts( texts=texts, metadatas=metadatas, ids=ids, batch_size=kwargs.get("batch_size"), batch_concurrency=kwargs.get("batch_concurrency"), overwrite_concurrency=kwargs.get("overwrite_concurrency"), ) return astra_db_store # type: ignore[return-value]
[docs] @classmethod async def afrom_texts( cls: Type[ADBVST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, **kwargs: Any, ) -> ADBVST: """从原始文本创建一个Astra DB向量存储。 参数: texts: 要插入的文本。 embedding: 存储中要使用的嵌入函数。 metadatas: 文本的元数据字典。 ids: 要与文本关联的ID。 **kwargs: 您可以传递任何您想要的参数 给 :meth:`~add_texts` 和/或 'AstraDB' 构造函数 (有关详细信息,请参阅这些方法)。这些参数将被 传递到相应的方法中。 返回: 一个 `AstraDb` 向量存储。 """ astra_db_store = AstraDB._from_kwargs(embedding, **kwargs) await astra_db_store.aadd_texts( texts=texts, metadatas=metadatas, ids=ids, batch_size=kwargs.get("batch_size"), batch_concurrency=kwargs.get("batch_concurrency"), overwrite_concurrency=kwargs.get("overwrite_concurrency"), ) return astra_db_store # type: ignore[return-value]
[docs] @classmethod def from_documents( cls: Type[ADBVST], documents: List[Document], embedding: Embeddings, **kwargs: Any, ) -> ADBVST: """从文档列表创建一个Astra DB向量存储。 这是一个实用方法,它委托给'from_texts'(请参阅该方法)。 参数:参见'from_texts',不同之处在于这里您必须提供'documents'来替代'texts'和'metadatas'。 返回: 一个`AstraDB`向量存储。 """ return super().from_documents(documents, embedding, **kwargs)