from __future__ import annotations
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.utils import gather_with_concurrency
from langchain_core.utils.iter import batch_iterate
from langchain_core.vectorstores import VectorStore
from langchain_community.utilities.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from astrapy.db import AstraDB as LibAstraDB
from astrapy.db import AsyncAstraDB
ADBVST = TypeVar("ADBVST", bound="AstraDB")
T = TypeVar("T")
U = TypeVar("U")
DocDict = Dict[str, Any] # dicts expressing entries to insert
# Batch/concurrency default values (if parameters not provided):
# Size of batches for bulk insertions:
# (20 is the max batch size for the HTTP API at the time of writing)
DEFAULT_BATCH_SIZE = 20
# Number of threads to insert batches concurrently:
DEFAULT_BULK_INSERT_BATCH_CONCURRENCY = 16
# Number of threads in a batch to insert pre-existing entries:
DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY = 10
# Number of threads (for deleting multiple rows concurrently):
DEFAULT_BULK_DELETE_CONCURRENCY = 20
def _unique_list(lst: List[T], key: Callable[[T], U]) -> List[T]:
visited_keys: Set[U] = set()
new_lst = []
for item in lst:
item_key = key(item)
if item_key not in visited_keys:
visited_keys.add(item_key)
new_lst.append(item)
return new_lst
[docs]@deprecated(
since="0.0.21",
removal="0.3.0",
alternative_import="langchain_astradb.AstraDBVectorStore",
)
class AstraDB(VectorStore):
@staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if filter_dict is None:
return {}
else:
metadata_filter = {}
for k, v in filter_dict.items():
if k and k[0] == "$":
if isinstance(v, list):
metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v]
else:
metadata_filter[k] = AstraDB._filter_to_metadata(v) # type: ignore[assignment]
else:
metadata_filter[f"metadata.{k}"] = v
return metadata_filter
[docs] def __init__(
self,
*,
embedding: Embeddings,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[LibAstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
metric: Optional[str] = None,
batch_size: Optional[int] = None,
bulk_insert_batch_concurrency: Optional[int] = None,
bulk_insert_overwrite_concurrency: Optional[int] = None,
bulk_delete_concurrency: Optional[int] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""封装了用于向量存储工作负载的 DataStax Astra DB。
要快速开始并了解详情,请访问
https://docs.datastax.com/en/astra/astra-db-vector/
示例:
.. code-block:: python
from langchain_community.vectorstores import AstraDB
from langchain_openai.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
vectorstore = AstraDB(
embedding=embeddings,
collection_name="my_store",
token="AstraCS:...",
api_endpoint="https://<DB-ID>-<REGION>.apps.astra.datastax.com"
)
vectorstore.add_texts(["Giraffes", "All good here"])
results = vectorstore.similarity_search("Everything's ok", k=1)
参数:
embedding: 要使用的嵌入函数。
collection_name: 要创建/使用的 Astra DB 集合的名称。
token: 用于 Astra DB 使用的 API 令牌。
api_endpoint: API 端点的完整 URL,例如
`https://<DB-ID>-us-east1.apps.astra.datastax.com`。
astra_db_client: *token+api_endpoint 的替代方案*,
您可以传递一个已创建的 'astrapy.db.AstraDB' 实例。
async_astra_db_client: *token+api_endpoint 的替代方案*,
您可以传递一个已创建的 'astrapy.db.AsyncAstraDB' 实例。
namespace: 创建集合的命名空间(又名 keyspace)。
默认为数据库的 "default namespace"。
metric: 要在 Astra DB 中使用的相似性函数之一。
如果省略,将使用 Astra DB API 的默认值(即 "cosine" - 但是,
出于性能原因,如果嵌入已归一化为一,建议使用 "dot_product")。
batch_size: 用于批量插入的批量大小。
bulk_insert_batch_concurrency: 并发插入批次的线程数或协程数。
bulk_insert_overwrite_concurrency: 在批次中插入已存在条目的线程数或协程数。
bulk_delete_concurrency: 并发删除多行的线程数。
pre_delete_collection: 是否在创建集合之前删除集合。
如果为 False 并且集合已存在,则将使用该集合。
注意:
对于同步 :meth:`~add_texts` 中的并发性,根据经验,在典型客户端机器上,
建议将数量 bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency
保持远低于 1000,以避免耗尽客户端的多线程/网络资源。
预设的默认值相对保守,以满足大多数机器的规格,但一个明智的选择可能是:
- bulk_insert_batch_concurrency = 80
- bulk_insert_overwrite_concurrency = 10
需要进行一些实验来找到最佳结果,
这取决于机器/网络规格以及预期工作负载(特别是写入是现有 id 的更新的频率)。
请记住,您也可以将并发设置传递给 :meth:`~add_texts` 和 :meth:`~add_documents` 的各个调用。
"""
self.embedding = embedding
self.collection_name = collection_name
self.token = token
self.api_endpoint = api_endpoint
self.namespace = namespace
# Concurrency settings
self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
self.bulk_insert_batch_concurrency: int = (
bulk_insert_batch_concurrency or DEFAULT_BULK_INSERT_BATCH_CONCURRENCY
)
self.bulk_insert_overwrite_concurrency: int = (
bulk_insert_overwrite_concurrency
or DEFAULT_BULK_INSERT_OVERWRITE_CONCURRENCY
)
self.bulk_delete_concurrency: int = (
bulk_delete_concurrency or DEFAULT_BULK_DELETE_CONCURRENCY
)
# "vector-related" settings
self.metric = metric
embedding_dimension: Union[int, Awaitable[int], None] = None
if setup_mode == SetupMode.ASYNC:
embedding_dimension = self._aget_embedding_dimension()
elif setup_mode == SetupMode.SYNC:
embedding_dimension = self._get_embedding_dimension()
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
)
self.astra_db = self.astra_env.astra_db
self.async_astra_db = self.astra_env.async_astra_db
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
def _get_embedding_dimension(self) -> int:
return len(self.embedding.embed_query(text="This is a sample sentence."))
async def _aget_embedding_dimension(self) -> int:
return len(await self.embedding.aembed_query(text="This is a sample sentence."))
@property
def embeddings(self) -> Embeddings:
return self.embedding
@staticmethod
def _dont_flip_the_cos_score(similarity0to1: float) -> float:
"""保持客户端的相似度不变,因为它已经在[0:1]范围内。"""
return similarity0to1
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""基础的API调用已经返回了一个“适当的分数”,即在[0, 1]之间,数值越高表示*更相似*,因此这里的最终分数转换不会颠倒区间:
"""
return self._dont_flip_the_cos_score
[docs] def clear(self) -> None:
"""清空集合中存储的所有条目。"""
self.astra_env.ensure_db_setup()
self.collection.delete_many({})
[docs] async def aclear(self) -> None:
"""清空集合中存储的所有条目。"""
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many({}) # type: ignore[union-attr]
[docs] def delete_by_document_id(self, document_id: str) -> bool:
"""从存储中删除单个文档,给定其文档ID。
参数:
document_id:文档ID
返回值:
如果确实已删除文档,则为True,如果未找到ID,则为False。
"""
self.astra_env.ensure_db_setup()
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
) == 1
[docs] async def adelete_by_document_id(self, document_id: str) -> bool:
"""从存储中删除单个文档,给定其文档ID。
参数:
document_id:文档ID
返回值:
如果确实已删除文档,则为True,如果未找到ID,则为False。
"""
await self.astra_env.aensure_db_setup()
deletion_response = await self.async_collection.delete_one(document_id)
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
) == 1
[docs] def delete(
self,
ids: Optional[List[str]] = None,
concurrency: Optional[int] = None,
**kwargs: Any,
) -> Optional[bool]:
"""根据向量ID删除。
参数:
ids:要删除的ID列表。
并发性:发出单个文档删除请求的线程的最大数量。
默认为实例级设置。
返回:
如果删除成功则为True,否则为False。
"""
if kwargs:
warnings.warn(
"Method 'delete' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
if ids is None:
raise ValueError("No ids provided to delete.")
_max_workers = concurrency or self.bulk_delete_concurrency
with ThreadPoolExecutor(max_workers=_max_workers) as tpe:
_ = list(
tpe.map(
self.delete_by_document_id,
ids,
)
)
return True
[docs] async def adelete(
self,
ids: Optional[List[str]] = None,
concurrency: Optional[int] = None,
**kwargs: Any,
) -> Optional[bool]:
"""根据向量ID删除。
参数:
ids:要删除的ID列表。
concurrency:单个文档删除请求的最大并发数。
默认为实例级设置。
**kwargs:子类可能使用的其他关键字参数。
返回:
如果删除成功则为True,否则为False。
"""
if kwargs:
warnings.warn(
"Method 'adelete' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
if ids is None:
raise ValueError("No ids provided to delete.")
return all(
await gather_with_concurrency(
concurrency, *[self.adelete_by_document_id(doc_id) for doc_id in ids]
)
)
[docs] def delete_collection(self) -> None:
"""完全从数据库中删除集合(与:meth:`~clear`相反,后者仅清空集合)。
存储的数据将丢失且无法恢复,资源将被释放。
请谨慎使用。
"""
self.astra_env.ensure_db_setup()
self.astra_db.delete_collection(
collection_name=self.collection_name,
)
[docs] async def adelete_collection(self) -> None:
"""完全从数据库中删除集合(与:meth:`~aclear`相反,后者仅清空集合)。
存储的数据将丢失且无法恢复,资源将被释放。
请谨慎使用。
"""
await self.astra_env.aensure_db_setup()
await self.async_astra_db.delete_collection(
collection_name=self.collection_name,
)
@staticmethod
def _get_documents_to_insert(
texts: Iterable[str],
embedding_vectors: List[List[float]],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[DocDict]:
if ids is None:
ids = [uuid.uuid4().hex for _ in texts]
if metadatas is None:
metadatas = [{} for _ in texts]
#
documents_to_insert = [
{
"content": b_txt,
"_id": b_id,
"$vector": b_emb,
"metadata": b_md,
}
for b_txt, b_emb, b_id, b_md in zip(
texts,
embedding_vectors,
ids,
metadatas,
)
]
# make unique by id, keeping the last
uniqued_documents_to_insert = _unique_list(
documents_to_insert[::-1],
lambda document: document["_id"],
)[::-1]
return uniqued_documents_to_insert
@staticmethod
def _get_missing_from_batch(
document_batch: List[DocDict], insert_result: Dict[str, Any]
) -> Tuple[List[str], List[DocDict]]:
if "status" not in insert_result:
raise ValueError(
f"API Exception while running bulk insertion: {str(insert_result)}"
)
batch_inserted = insert_result["status"]["insertedIds"]
# estimation of the preexisting documents that failed
missed_inserted_ids = {document["_id"] for document in document_batch} - set(
batch_inserted
)
errors = insert_result.get("errors", [])
# careful for other sources of error other than "doc already exists"
num_errors = len(errors)
unexpected_errors = any(
error.get("errorCode") != "DOCUMENT_ALREADY_EXISTS" for error in errors
)
if num_errors != len(missed_inserted_ids) or unexpected_errors:
raise ValueError(
f"API Exception while running bulk insertion: {str(errors)}"
)
# deal with the missing insertions as upserts
missing_from_batch = [
document
for document in document_batch
if document["_id"] in missed_inserted_ids
]
return batch_inserted, missing_from_batch
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
*,
batch_size: Optional[int] = None,
batch_concurrency: Optional[int] = None,
overwrite_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""通过嵌入将文本传递并将其添加到向量存储中。
如果传递了显式的ids,那些id已经在存储中的条目将被替换。
参数:
texts:要添加到向量存储中的文本。
metadatas:元数据的可选列表。
ids:id的可选列表。
batch_size:每个API调用中的文档数量。
检查底层Astra DB HTTP API规范以获取最大值
(在编写本文时为20)。如果未提供,默认为
实例级别设置。
batch_concurrency:处理插入批次的线程数
并发。如果未提供,默认为实例级别
设置。
overwrite_concurrency:处理线程数
每个批次中的现有文档(需要单独的API调用)。
如果未提供,默认为实例级别设置。
注意:
元数据字典中允许的字段名称有限制
来自底层Astra DB API。例如,`$`(美元符号)不能在字典键中使用。
有关详细信息,请参阅此文档:
https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html
返回:
添加的文本的id列表。
"""
if kwargs:
warnings.warn(
"Method 'add_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
self.astra_env.ensure_db_setup()
embedding_vectors = self.embedding.embed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = self.collection.insert_many(
documents=document_batch,
options={"ordered": False},
partial_failures_allowed=True,
)
batch_inserted, missing_from_batch = self._get_missing_from_batch(
document_batch, im_result
)
def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = self.collection.find_one_and_replace(
filter={"_id": missing_document["_id"]},
replacement=missing_document,
)
return replacement_result["data"]["document"]["_id"]
_u_max_workers = (
overwrite_concurrency or self.bulk_insert_overwrite_concurrency
)
with ThreadPoolExecutor(max_workers=_u_max_workers) as tpe2:
batch_replaced = list(
tpe2.map(
_handle_missing_document,
missing_from_batch,
)
)
return batch_inserted + batch_replaced
_b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
with ThreadPoolExecutor(max_workers=_b_max_workers) as tpe:
all_ids_nested = tpe.map(
_handle_batch,
batch_iterate(
batch_size or self.batch_size,
documents_to_insert,
),
)
return [iid for id_list in all_ids_nested for iid in id_list]
[docs] async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
*,
batch_size: Optional[int] = None,
batch_concurrency: Optional[int] = None,
overwrite_concurrency: Optional[int] = None,
**kwargs: Any,
) -> List[str]:
"""通过嵌入将文本传递并将其添加到向量存储中。
如果传递了显式的ids,那些id已经在存储中的条目将被替换。
参数:
texts:要添加到向量存储中的文本。
metadatas:元数据的可选列表。
ids:id的可选列表。
batch_size:每个API调用中的文档数量。
检查底层Astra DB HTTP API规范以获取最大值
(在编写本文时为20)。如果未提供,默认为
实例级别设置。
batch_concurrency:处理插入批次的线程数
并发。如果未提供,默认为实例级别
设置。
overwrite_concurrency:处理线程数
每个批次中的现有文档(需要单独的API调用)。
如果未提供,默认为实例级别设置。
注意:
元数据字典中允许的字段名称有限制
来自底层Astra DB API。例如,`$`(美元符号)不能在字典键中使用。
有关详细信息,请参阅此文档:
https://docs.datastax.com/en/astra/astra-db-vector/api-reference/data-api.html
返回:
添加的文本的id列表。
"""
if kwargs:
warnings.warn(
"Method 'aadd_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(kwargs.keys()))}), "
"which will be ignored."
)
await self.astra_env.aensure_db_setup()
embedding_vectors = await self.embedding.aembed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = await self.async_collection.insert_many(
documents=document_batch,
options={"ordered": False},
partial_failures_allowed=True,
)
batch_inserted, missing_from_batch = self._get_missing_from_batch(
document_batch, im_result
)
async def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = await self.async_collection.find_one_and_replace(
filter={"_id": missing_document["_id"]},
replacement=missing_document,
)
return replacement_result["data"]["document"]["_id"]
_u_max_workers = (
overwrite_concurrency or self.bulk_insert_overwrite_concurrency
)
batch_replaced = await gather_with_concurrency(
_u_max_workers,
*[_handle_missing_document(doc) for doc in missing_from_batch],
)
return batch_inserted + batch_replaced
_b_max_workers = batch_concurrency or self.bulk_insert_batch_concurrency
all_ids_nested = await gather_with_concurrency(
_b_max_workers,
*[
_handle_batch(batch)
for batch in batch_iterate(
batch_size or self.batch_size,
documents_to_insert,
)
],
)
return [iid for id_list in all_ids_nested for iid in id_list]
[docs] def similarity_search_with_score_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""返回与嵌入向量最相似的文档,包括分数和ID。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表,包括(文档,分数,ID)。
"""
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
hits = list(
self.collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
},
)
)
#
return [
(
Document(
page_content=hit["content"],
metadata=hit["metadata"],
),
hit["$similarity"],
hit["_id"],
)
for hit in hits
]
[docs] async def asimilarity_search_with_score_id_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""返回与嵌入向量最相似的文档,包括分数和ID。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表,包括(文档,分数,ID)。
"""
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
#
return [
(
Document(
page_content=hit["content"],
metadata=hit["metadata"],
),
hit["$similarity"],
hit["_id"],
)
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
},
)
]
[docs] def similarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""返回与查询最相似的文档,包括分数和ID。
参数:
query:要查找类似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询最相似的文档的(文档,分数,ID)列表。
"""
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
[docs] async def asimilarity_search_with_score_id(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float, str]]:
"""返回与查询最相似的文档,包括分数和ID。
参数:
query:要查找类似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询最相似的文档的(文档,分数,ID)列表。
"""
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
[docs] def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""返回与嵌入向量最相似的文档及其分数。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
filter: 要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表及其分数。
"""
return [
(doc, score)
for (doc, score, doc_id) in self.similarity_search_with_score_id_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
]
[docs] async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""返回与嵌入向量最相似的文档及其分数。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
filter: 要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表及其分数。
"""
return [
(doc, score)
for (
doc,
score,
doc_id,
) in await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
]
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询最相似的文档列表。
"""
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
[docs] async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询最相似的文档列表。
"""
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding:要查找与之相似文档的嵌入。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表。
"""
return [
doc
for doc, _ in self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
)
]
[docs] async def asimilarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding:要查找与之相似文档的嵌入。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档列表。
"""
return [
doc
for doc, _ in await self.asimilarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
)
]
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档及其分数。
参数:
query:要查找相似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档及其分数的列表。
"""
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
[docs] async def asimilarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档及其分数。
参数:
query:要查找相似文档的查询。
k:要返回的文档数量。默认为4。
filter:要应用的元数据过滤器。
返回:
与查询向量最相似的文档及其分数的列表。
"""
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
@staticmethod
def _get_mmr_hits(
embedding: List[float], k: int, lambda_mult: float, prefetch_hits: List[DocDict]
) -> List[Document]:
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
k=k,
lambda_mult=lambda_mult,
)
mmr_hits = [
prefetch_hit
for prefetch_index, prefetch_hit in enumerate(prefetch_hits)
if prefetch_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["content"],
metadata=hit["metadata"],
)
for hit in mmr_hits
]
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 查找与之相似的文档的嵌入。
k: 要返回的文档数量。
fetch_k: 要获取的文档数量,以传递给MMR算法。
lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度,
其中0对应于最大多样性,1对应于最小多样性。
filter: 要应用的元数据过滤器。
返回:
通过最大边际相关性选择的文档列表。
"""
self.astra_env.ensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = list(
self.collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
},
)
)
return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
[docs] async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 查找与之相似的文档的嵌入。
k: 要返回的文档数量。
fetch_k: 要获取的文档数量,以传递给MMR算法。
lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度,
其中0对应于最大多样性,1对应于最小多样性。
filter: 要应用的元数据过滤器。
返回:
通过最大边际相关性选择的文档列表。
"""
await self.astra_env.aensure_db_setup()
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = [
hit
async for hit in self.async_collection.paginated_find(
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
projection={
"_id": 1,
"content": 1,
"metadata": 1,
"$vector": 1,
},
)
]
return self._get_mmr_hits(embedding, k, lambda_mult, prefetch_hits)
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找相似文档的查询。
k:要返回的文档数量。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果中多样性的程度,
0对应最大多样性,1对应最小多样性。
filter:要应用的元数据过滤器。
返回:
通过最大边际相关性选择的文档列表。
"""
embedding_vector = self.embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
[docs] async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找相似文档的查询。
k:要返回的文档数量。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果中多样性的程度,
0对应最大多样性,1对应最小多样性。
filter:要应用的元数据过滤器。
返回:
通过最大边际相关性选择的文档列表。
"""
embedding_vector = await self.embedding.aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
embedding_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
@classmethod
def _from_kwargs(
cls: Type[ADBVST],
embedding: Embeddings,
**kwargs: Any,
) -> ADBVST:
known_kwargs = {
"collection_name",
"token",
"api_endpoint",
"astra_db_client",
"async_astra_db_client",
"namespace",
"metric",
"batch_size",
"bulk_insert_batch_concurrency",
"bulk_insert_overwrite_concurrency",
"bulk_delete_concurrency",
"batch_concurrency",
"overwrite_concurrency",
}
if kwargs:
unknown_kwargs = set(kwargs.keys()) - known_kwargs
if unknown_kwargs:
warnings.warn(
"Method 'from_texts' of AstraDB vector store invoked with "
f"unsupported arguments ({', '.join(sorted(unknown_kwargs))}), "
"which will be ignored."
)
collection_name: str = kwargs["collection_name"]
token = kwargs.get("token")
api_endpoint = kwargs.get("api_endpoint")
astra_db_client = kwargs.get("astra_db_client")
async_astra_db_client = kwargs.get("async_astra_db_client")
namespace = kwargs.get("namespace")
metric = kwargs.get("metric")
return cls(
embedding=embedding,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
metric=metric,
batch_size=kwargs.get("batch_size"),
bulk_insert_batch_concurrency=kwargs.get("bulk_insert_batch_concurrency"),
bulk_insert_overwrite_concurrency=kwargs.get(
"bulk_insert_overwrite_concurrency"
),
bulk_delete_concurrency=kwargs.get("bulk_delete_concurrency"),
)
[docs] @classmethod
def from_texts(
cls: Type[ADBVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
"""从原始文本创建一个Astra DB向量存储。
参数:
texts: 要插入的文本。
embedding: 存储中要使用的嵌入函数。
metadatas: 文本的元数据字典。
ids: 要与文本关联的ID。
**kwargs: 您可以传递任何您想要的参数
给 :meth:`~add_texts` 和/或 'AstraDB' 构造函数
(有关详细信息,请参阅这些方法)。这些参数将被
传递到相应的方法中。
返回:
一个 `AstraDb` 向量存储。
"""
astra_db_store = AstraDB._from_kwargs(embedding, **kwargs)
astra_db_store.add_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
batch_size=kwargs.get("batch_size"),
batch_concurrency=kwargs.get("batch_concurrency"),
overwrite_concurrency=kwargs.get("overwrite_concurrency"),
)
return astra_db_store # type: ignore[return-value]
[docs] @classmethod
async def afrom_texts(
cls: Type[ADBVST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> ADBVST:
"""从原始文本创建一个Astra DB向量存储。
参数:
texts: 要插入的文本。
embedding: 存储中要使用的嵌入函数。
metadatas: 文本的元数据字典。
ids: 要与文本关联的ID。
**kwargs: 您可以传递任何您想要的参数
给 :meth:`~add_texts` 和/或 'AstraDB' 构造函数
(有关详细信息,请参阅这些方法)。这些参数将被
传递到相应的方法中。
返回:
一个 `AstraDb` 向量存储。
"""
astra_db_store = AstraDB._from_kwargs(embedding, **kwargs)
await astra_db_store.aadd_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
batch_size=kwargs.get("batch_size"),
batch_concurrency=kwargs.get("batch_concurrency"),
overwrite_concurrency=kwargs.get("overwrite_concurrency"),
)
return astra_db_store # type: ignore[return-value]
[docs] @classmethod
def from_documents(
cls: Type[ADBVST],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> ADBVST:
"""从文档列表创建一个Astra DB向量存储。
这是一个实用方法,它委托给'from_texts'(请参阅该方法)。
参数:参见'from_texts',不同之处在于这里您必须提供'documents'来替代'texts'和'metadatas'。
返回:
一个`AstraDB`向量存储。
"""
return super().from_documents(documents, embedding, **kwargs)