from __future__ import annotations
import functools
import uuid
import warnings
from itertools import islice
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import numpy as np
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.docstore.document import Document
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = Dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
[docs]class QdrantException(Exception):
"""与`Qdrant`相关的异常。"""
[docs]def sync_call_fallback(method: Callable) -> Callable:
"""如果异步方法未实现,则调用类的同步方法的装饰器。此装饰器可能仅用于在类中定义为异步的方法。
"""
@functools.wraps(method)
async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await method(self, *args, **kwargs)
except NotImplementedError:
# If the async method is not implemented, call the synchronous method
# by removing the first letter from the method name. For example,
# if the async method is called ``aaad_texts``, the synchronous method
# will be called ``aad_texts``.
return await run_in_executor(
None, getattr(self, method.__name__[1:]), *args, **kwargs
)
return wrapper
[docs]@deprecated(
since="0.0.37", removal="0.3.0", alternative_import="langchain_qdrant.Qdrant"
)
class Qdrant(VectorStore):
"""`Qdrant`向量存储。
要使用,您应该已安装``qdrant-client``包。
示例:
.. code-block:: python
from qdrant_client import QdrantClient
from langchain_community.vectorstores import Qdrant
client = QdrantClient()
collection_name = "MyCollection"
qdrant = Qdrant(client, collection_name, embedding_function)
"""
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
VECTOR_NAME = None
[docs] def __init__(
self,
client: Any,
collection_name: str,
embeddings: Optional[Embeddings] = None,
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
distance_strategy: str = "COSINE",
vector_name: Optional[str] = VECTOR_NAME,
async_client: Optional[Any] = None,
embedding_function: Optional[Callable] = None, # deprecated
):
"""使用必要的组件进行初始化。"""
try:
import qdrant_client
except ImportError:
raise ImportError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
if not isinstance(client, qdrant_client.QdrantClient):
raise ValueError(
f"client should be an instance of qdrant_client.QdrantClient, "
f"got {type(client)}"
)
if async_client is not None and not isinstance(
async_client, qdrant_client.AsyncQdrantClient
):
raise ValueError(
f"async_client should be an instance of qdrant_client.AsyncQdrantClient"
f"got {type(async_client)}"
)
if embeddings is None and embedding_function is None:
raise ValueError(
"`embeddings` value can't be None. Pass `Embeddings` instance."
)
if embeddings is not None and embedding_function is not None:
raise ValueError(
"Both `embeddings` and `embedding_function` are passed. "
"Use `embeddings` only."
)
self._embeddings = embeddings
self._embeddings_function = embedding_function
self.client: qdrant_client.QdrantClient = client
self.async_client: Optional[qdrant_client.AsyncQdrantClient] = async_client
self.collection_name = collection_name
self.content_payload_key = content_payload_key or self.CONTENT_KEY
self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
self.vector_name = vector_name or self.VECTOR_NAME
if embedding_function is not None:
warnings.warn(
"Using `embedding_function` is deprecated. "
"Pass `Embeddings` instance to `embeddings` instead."
)
if not isinstance(embeddings, Embeddings):
warnings.warn(
"`embeddings` should be an instance of `Embeddings`."
"Using `embeddings` as `embedding_function` which is deprecated"
)
self._embeddings_function = embeddings
self._embeddings = None
self.distance_strategy = distance_strategy.upper()
@property
def embeddings(self) -> Optional[Embeddings]:
return self._embeddings
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
"""运行更多的文本通过嵌入并添加到向量存储。
参数:
texts:要添加到向量存储的字符串的可迭代对象。
metadatas:与文本相关联的元数据的可选列表。
ids:
与文本相关联的id的可选列表。 Id必须是类似uuid的字符串。
batch_size:
每个请求上传多少向量。
默认值:64
返回:
将文本添加到向量存储中的id列表。
"""
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, metadatas, ids, batch_size
):
self.client.upsert(
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
return added_ids
[docs] @sync_call_fallback
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
**kwargs: Any,
) -> List[str]:
"""运行更多的文本通过嵌入并添加到向量存储。
参数:
texts:要添加到向量存储的字符串的可迭代对象。
metadatas:与文本相关联的元数据的可选列表。
ids:
与文本相关联的id的可选列表。 Id必须是类似uuid的字符串。
batch_size:
每个请求上传多少向量。
默认值:64
返回:
将文本添加到向量存储中的id列表。
"""
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
added_ids = []
async for batch_ids, points in self._agenerate_rest_batches(
texts, metadatas, ids, batch_size
):
await self.async_client.upsert(
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)
return added_ids
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query: 要查找类似文档的文本。
k: 要返回的文档数量。默认为4。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,那么不太相似的结果将不会被返回。
返回结果的分数可能高于或低于
阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的任何其他命名参数
返回:
与查询最相似的文档列表。
"""
results = self.similarity_search_with_score(
query,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
[docs] @sync_call_fallback
async def asimilarity_search(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
filter:按元数据过滤。默认为None。
返回:
与查询最相似的文档列表。
"""
results = await self.asimilarity_search_with_score(query, k, filter, **kwargs)
return list(map(itemgetter(0), results))
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档。
参数:
query: 要查找类似文档的文本。
k: 要返回的文档数量。默认为4。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,将不会返回较不相似的结果。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的其他命名参数
返回:
与查询文本最相似的文档列表,以及每个文档的距离。
"""
return self.similarity_search_with_score_by_vector(
self._embed_query(query),
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
[docs] @sync_call_fallback
async def asimilarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
filter:按元数据过滤。默认为None。
search_params:其他搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,将不返回较不相似的结果。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回较高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给AsyncQdrantClient.Search()的其他命名参数。
返回:
与查询文本最相似的文档列表,以及每个文档的距离。
"""
query_embedding = await self._aembed_query(query)
return await self.asimilarity_search_with_score_by_vector(
query_embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入向量。
k: 要返回的文档数量。默认为4。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会被返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少个副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的其他命名参数
返回:
查询结果中最相似的文档列表。
"""
results = self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
[docs] @sync_call_fallback
async def asimilarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入向量。
k: 要返回的文档数量。默认为4。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会返回。
返回结果的分数可能高于或低于
阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义应查询多少副本
才能返回结果。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给AsyncQdrantClient.Search()的其他命名参数。
返回:
与查询最相似的文档列表。
"""
results = await self.asimilarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
search_params=search_params,
offset=offset,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
[docs] def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入向量。
k: 要返回的文档数量。默认为4。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会被返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的任何其他命名参数
返回:
查询文本最相似的文档列表,以及每个文档的距离。
"""
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=qdrant_filter,
search_params=search_params,
limit=k,
offset=offset,
with_payload=True,
with_vectors=False, # Langchain does not expect vectors to be returned
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return [
(
self._document_from_scored_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
result.score,
)
for result in results
]
[docs] @sync_call_fallback
async def asimilarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
offset: int = 0,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入向量。
k: 要返回的文档数量。默认为4。
filter: 根据元数据进行过滤。默认为None。
search_params: 附加的搜索参数
offset:
要返回的第一个结果的偏移量。
可用于分页结果。
注意:较大的偏移值可能会导致性能问题。
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会被返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少个副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给AsyncQdrantClient.Search()的其他命名参数。
返回:
查询文本最相似的文档列表,以及每个文档的距离。
"""
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
if filter is not None and isinstance(filter, dict):
warnings.warn(
"Using dict as a `filter` is deprecated. Please use qdrant-client "
"filters directly: "
"https://qdrant.tech/documentation/concepts/filtering/",
DeprecationWarning,
)
qdrant_filter = self._qdrant_filter_from_dict(filter)
else:
qdrant_filter = filter
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, embedding) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=qdrant_filter,
search_params=search_params,
limit=k,
offset=offset,
with_payload=True,
with_vectors=False, # Langchain does not expect vectors to be returned
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return [
(
self._document_from_scored_point(
result,
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
result.score,
)
for result in results
]
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query: 要查找类似文档的文本。
k: 要返回的文档数量。默认为4。
fetch_k: 要获取以传递给MMR算法的文档数量。默认为20。
lambda_mult: 0到1之间的数字,确定结果之间多样性的程度,其中0对应最大多样性,1对应最小多样性。默认为0.5。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数。
score_threshold:
定义结果的最小分数阈值。
如果定义了,那么不会返回相似度较低的结果。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少个副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的任何其他命名参数。
返回:
通过最大边际相关性选择的文档列表。
"""
query_embedding = self._embed_query(query)
return self.max_marginal_relevance_search_by_vector(
query_embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
search_params=search_params,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
[docs] @sync_call_fallback
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query: 要查找类似文档的文本。
k: 要返回的文档数量。默认为4。
fetch_k: 要获取以传递给MMR算法的文档数量。默认为20。
lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度,0表示最大多样性,1表示最小多样性。默认为0.5。
filter: 按元数据过滤。默认为None。
search_params: 附加搜索参数
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
传递给AsyncQdrantClient.Search()的任何其他命名参数。
返回:
通过最大边际相关性选择的文档列表。
"""
query_embedding = await self._aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
query_embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
search_params=search_params,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 用于查找类似文档的嵌入。
k: 要返回的文档数量。默认为4。
fetch_k: 要获取以传递给MMR算法的文档数量。
lambda_mult: 介于0和1之间的数字,确定结果之间多样性的程度,
其中0对应于最大多样性,1对应于最小多样性。
默认为0.5。
filter: 按元数据过滤。默认为None。
search_params: 附加的搜索参数
score_threshold:
定义结果的最小分数阈值。
如果定义了,不太相似的结果将不会返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的任何其他命名参数
返回:
通过最大边际相关性选择的文档列表。
"""
results = self.max_marginal_relevance_search_with_score_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
search_params=search_params,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
[docs] @sync_call_fallback
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
默认为20。
lambda_mult:0到1之间的数字,确定结果中多样性的程度,
其中0对应最大多样性,1对应最小多样性。
默认为0.5。
filter:按元数据过滤。默认为None。
search_params:额外的搜索参数
score_threshold:
为结果定义最小分数阈值。
如果定义了,那么不会返回较不相似的结果。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回更高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给AsyncQdrantClient.Search()的其他命名参数。
返回:
通过最大边际相关性和距离选择的文档列表。
"""
results = await self.amax_marginal_relevance_search_with_score_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
search_params=search_params,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
return list(map(itemgetter(0), results))
[docs] def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。默认为20。
lambda_mult:0到1之间的数字,确定结果中多样性的程度,0对应最大多样性,1对应最小多样性。默认为0.5。
filter:按元数据过滤。默认为None。
search_params:额外的搜索参数
score_threshold:
为结果定义最小分数阈值。
如果定义了,相似性较低的结果将不会被返回。
返回结果的分数可能高于或低于阈值,具体取决于所使用的距离函数。
例如,对于余弦相似度,只会返回较高的分数。
consistency:
搜索的读一致性。定义在返回结果之前应查询多少副本。
值:
- int - 要查询的副本数量,所有查询的副本中应该存在这些值
- 'majority' - 查询所有副本,但返回大多数副本中存在的值
- 'quorum' - 查询大多数副本,返回所有这些副本中存在的值
- 'all' - 查询所有副本,并返回所有副本中存在的值
**kwargs:
要传递给QdrantClient.search()的任何其他命名参数
返回:
通过最大边际相关性和距离选择的文档列表。
"""
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = self.client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=search_params,
limit=fetch_k,
with_payload=True,
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point(
results[i],
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
results[i].score,
)
for i in mmr_selected
]
[docs] @sync_call_fallback
async def amax_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[MetadataFilter] = None,
search_params: Optional[common_types.SearchParams] = None,
score_threshold: Optional[float] = None,
consistency: Optional[common_types.ReadConsistency] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
默认为20。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,
其中0对应最大多样性,1对应最小多样性。
默认为0.5。
返回:
通过最大边际相关性和距离选择的文档列表。
"""
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
query_vector = embedding
if self.vector_name is not None:
query_vector = (self.vector_name, query_vector) # type: ignore[assignment]
results = await self.async_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=filter,
search_params=search_params,
limit=fetch_k,
with_payload=True,
with_vectors=True,
score_threshold=score_threshold,
consistency=consistency,
**kwargs,
)
embeddings = [
result.vector.get(self.vector_name) # type: ignore[index, union-attr]
if self.vector_name is not None
else result.vector
for result in results
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
(
self._document_from_scored_point(
results[i],
self.collection_name,
self.content_payload_key,
self.metadata_payload_key,
),
results[i].score,
)
for i in mmr_selected
]
[docs] def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""根据向量ID或其他条件删除。
参数:
ids:要删除的ID列表。
**kwargs:子类可能使用的其他关键字参数。
返回:
如果删除成功则为True,否则为False。
"""
from qdrant_client.http import models as rest
result = self.client.delete(
collection_name=self.collection_name,
points_selector=ids,
)
return result.status == rest.UpdateStatus.COMPLETED
[docs] @sync_call_fallback
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""根据向量ID或其他条件删除。
参数:
ids:要删除的ID列表。
**kwargs:子类可能使用的其他关键字参数。
返回:
如果删除成功则为True,否则为False。
"""
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
if self.async_client is None or isinstance(
self.async_client._client, AsyncQdrantLocal
):
raise NotImplementedError(
"QdrantLocal cannot interoperate with sync and async clients"
)
from qdrant_client.http import models as rest
result = await self.async_client.delete(
collection_name=self.collection_name,
points_selector=ids,
)
return result.status == rest.UpdateStatus.COMPLETED
[docs] @classmethod
def from_texts(
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
"""从文本列表构建Qdrant包装器。
参数:
texts:要在Qdrant中索引的文本列表。
embedding:`Embeddings`的子类,负责文本向量化。
metadatas:
可选的元数据列表。如果提供,必须与文本列表的长度相同。
ids:
与文本关联的可选id列表。Ids必须是类似uuid的字符串。
location:
如果是 `:memory:` - 使用内存中的Qdrant实例。
如果是 `str` - 将其用作 `url` 参数。
如果是 `None` - 回退到依赖 `host` 和 `port` 参数。
url:主机或 "Optional[scheme], host, Optional[port], Optional[prefix]" 的字符串。默认值:`None`
port:REST API接口的端口。默认值:6333
grpc_port:gRPC接口的端口。默认值:6334
prefer_grpc:
如果为true - 在自定义方法中尽可能使用gPRC接口。默认值:False
https:如果为true - 使用HTTPS(SSL)协议。默认值:None
api_key:用于在Qdrant Cloud中进行身份验证的API密钥。默认值:None
prefix:
如果不是None - 将前缀添加到REST URL路径中。
示例:service/v1 将导致 REST API 的URL 为
http://localhost:6333/service/v1/{qdrant-endpoint}。默认值:None
timeout:
REST和gRPC API请求的超时时间。
默认值:REST为5.0秒,gRPC为无限制
host:
Qdrant服务的主机名。如果url和host都为None,则设置为 'localhost'。默认值:None
path:
在本地模式下存储向量的路径。默认值:None
collection_name:
要使用的Qdrant集合的名称。如果未提供,将随机创建一个。默认值:None
distance_func:
距离函数。其中之一:"Cosine" / "Euclid" / "Dot"。默认值:"Cosine"
content_payload_key:
用于存储文档内容的有效载荷键。
默认值:"page_content"
metadata_payload_key:
用于存储文档元数据的有效载荷键。
默认值:"metadata"
vector_name:
在Qdrant内部使用的向量名称。
默认值:None
batch_size:
每个请求上传多少个向量。
默认值:64
shard_number:集合中的分片数。默认值为1,最小值为1。
replication_factor:
集合的复制因子。默认值为1,最小值为1。
定义将创建每个分片的副本数量。
仅在分布式模式下生效。
write_consistency_factor:
集合的写入一致性因子。默认值为1,最小值为1。
定义多少个副本应用操作,我们才认为操作成功。
增加此数字将使集合更具抗不一致性能力,但如果副本不足,将导致操作失败。
不会对性能产生任何影响。
仅在分布式模式下生效。
on_disk_payload:
如果为true - 点的有效载荷将不会存储在内存中。
每次请求时都会从磁盘中读取它。
通过(稍微)增加响应时间,此设置可以节省RAM。
注意:那些涉及过滤并且被索引的有效载荷值仍保留在RAM中。
hnsw_config:HNSW索引的参数
optimizers_config:优化器的参数
wal_config:Write-Ahead-Log的参数
quantization_config:
量化的参数,如果为None - 将禁用量化
init_from:
使用存储在另一个集合中的数据初始化此集合
force_recreate:
强制重新创建集合
**kwargs:
直接传递到REST客户端初始化的其他参数
这是一个用户友好的接口,可以:
1. 为每个文本创建嵌入
2. 默认情况下将Qdrant数据库初始化为内存中的文档存储库
(并可覆盖为远程文档存储库)
3. 将文本嵌入添加到Qdrant数据库
这旨在是一个快速入门的方式。
示例:
.. code-block:: python
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
qdrant = Qdrant.from_texts(texts, embeddings, "localhost")
"""
qdrant = cls.construct_instance(
texts,
embedding,
location,
url,
port,
grpc_port,
prefer_grpc,
https,
api_key,
prefix,
timeout,
host,
path,
collection_name,
distance_func,
content_payload_key,
metadata_payload_key,
vector_name,
shard_number,
replication_factor,
write_consistency_factor,
on_disk_payload,
hnsw_config,
optimizers_config,
wal_config,
quantization_config,
init_from,
on_disk,
force_recreate,
**kwargs,
)
qdrant.add_texts(texts, metadatas, ids, batch_size)
return qdrant
[docs] @classmethod
def from_existing_collection(
cls: Type[Qdrant],
embedding: Embeddings,
path: str,
collection_name: str,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
**kwargs: Any,
) -> Qdrant:
"""获取现有Qdrant集合的实例。
该方法将返回存储的实例,而不会插入任何新的嵌入。
"""
client, async_client = cls._generate_clients(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
return cls(
client=client,
async_client=async_client,
collection_name=collection_name,
embeddings=embedding,
**kwargs,
)
[docs] @classmethod
@sync_call_fallback
async def afrom_texts(
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
"""从文本列表构建Qdrant包装器。
参数:
texts:要在Qdrant中索引的文本列表。
embedding:`Embeddings`的子类,负责文本向量化。
metadatas:
可选的元数据列表。如果提供,必须与文本列表的长度相同。
ids:
与文本关联的可选id列表。Ids必须是类似uuid的字符串。
location:
如果是 `:memory:` - 使用内存中的Qdrant实例。
如果是 `str` - 将其用作 `url` 参数。
如果是 `None` - 回退到依赖 `host` 和 `port` 参数。
url:格式为 "Optional[scheme], host, Optional[port], Optional[prefix]" 的主机或字符串。默认值:`None`
port:REST API接口的端口。默认值:6333
grpc_port:gRPC接口的端口。默认值:6334
prefer_grpc:
如果为True - 在自定义方法中尽可能使用gRPC接口。默认值:False
https:如果为True - 使用HTTPS(SSL)协议。默认值:None
api_key:用于在Qdrant Cloud中进行身份验证的API密钥。默认值:None
prefix:
如果不是None - 将前缀添加到REST URL路径中。
示例:service/v1 将导致REST API的URL路径为
http://localhost:6333/service/v1/{qdrant-endpoint}。默认值:None
timeout:
REST和gRPC API请求的超时时间。
默认值:REST为5.0秒,gRPC为无限制
host:
Qdrant服务的主机名。如果url和host都为None,则设置为'localhost'。默认值:None
path:
在本地模式下存储向量的路径。默认值:None
collection_name:
要使用的Qdrant集合的名称。如果未提供,将随机创建。默认值:None
distance_func:
距离函数。可选值为:"Cosine" / "Euclid" / "Dot"。默认值:"Cosine"
content_payload_key:
用于存储文档内容的有效载荷键。默认值:"page_content"
metadata_payload_key:
用于存储文档元数据的有效载荷键。默认值:"metadata"
vector_name:
在Qdrant内部使用的向量名称。默认值:None
batch_size:
每个请求上传多少个向量。默认值:64
shard_number:集合中的分片数量。默认值为1,最小值为1。
replication_factor:
集合的复制因子。默认值为1,最小值为1。
定义每个分片将创建多少个副本。
仅在分布式模式下生效。
write_consistency_factor:
集合的写入一致性因子。默认值为1,最小值为1。
定义多少个副本应用操作,我们才认为操作成功。
增加此数字将使集合更具抗干扰性,但如果副本不足,将导致失败。
不会对性能产生任何影响。
仅在分布式模式下生效。
on_disk_payload:
如果为True - 点的有效载荷将不会存储在内存中。
每次请求时都会从磁盘中读取。此设置通过(轻微地)增加响应时间来节省RAM。
注意:涉及过滤并且已索引的有效载荷值仍保留在RAM中。
hnsw_config:HNSW索引的参数
optimizers_config:优化器的参数
wal_config:Write-Ahead-Log的参数
quantization_config:
量化的参数,如果为None - 将禁用量化
init_from:
使用存储在另一个集合中的数据初始化此集合
force_recreate:
强制重新创建集合
**kwargs:
直接传递给REST客户端初始化的其他参数
这是一个用户友好的接口,可以:
1. 为每个文本创建嵌入
2. 默认情况下将Qdrant数据库初始化为内存中的文档存储库
(并可覆盖为远程文档存储库)
3. 将文本嵌入添加到Qdrant数据库
这旨在是一个快速入门的方式。
示例:
.. code-block:: python
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
qdrant = await Qdrant.afrom_texts(texts, embeddings, "localhost")
"""
qdrant = await cls.aconstruct_instance(
texts,
embedding,
location,
url,
port,
grpc_port,
prefer_grpc,
https,
api_key,
prefix,
timeout,
host,
path,
collection_name,
distance_func,
content_payload_key,
metadata_payload_key,
vector_name,
shard_number,
replication_factor,
write_consistency_factor,
on_disk_payload,
hnsw_config,
optimizers_config,
wal_config,
quantization_config,
init_from,
on_disk,
force_recreate,
**kwargs,
)
await qdrant.aadd_texts(texts, metadatas, ids, batch_size)
return qdrant
[docs] @classmethod
def construct_instance(
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
try:
import qdrant_client # noqa
except ImportError:
raise ImportError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
from grpc import RpcError
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
# Just do a single quick embedding to get vector size
partial_embeddings = embedding.embed_documents(texts[:1])
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
client, async_client = cls._generate_clients(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
try:
# Skip any validation in case of forced collection recreate.
if force_recreate:
raise ValueError
# Get the vector configuration of the existing collection and vector, if it
# was specified. If the old configuration does not match the current one,
# an exception is being thrown.
collection_info = client.get_collection(collection_name=collection_name)
current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
# Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: # type: ignore[union-attr]
raise QdrantException(
f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " # type: ignore[union-attr]
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr]
)
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_distance_func} similarity, but requested "
f"{distance_func}. Please set `distance_func` parameter to "
f"`{current_distance_func}` if you want to reuse it. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
on_disk=on_disk,
)
# If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
qdrant = cls(
client=client,
collection_name=collection_name,
embeddings=embedding,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
async_client=async_client,
)
return qdrant
[docs] @classmethod
async def aconstruct_instance(
cls: Type[Qdrant],
texts: List[str],
embedding: Embeddings,
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
collection_name: Optional[str] = None,
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None,
replication_factor: Optional[int] = None,
write_consistency_factor: Optional[int] = None,
on_disk_payload: Optional[bool] = None,
hnsw_config: Optional[common_types.HnswConfigDiff] = None,
optimizers_config: Optional[common_types.OptimizersConfigDiff] = None,
wal_config: Optional[common_types.WalConfigDiff] = None,
quantization_config: Optional[common_types.QuantizationConfig] = None,
init_from: Optional[common_types.InitFrom] = None,
on_disk: Optional[bool] = None,
force_recreate: bool = False,
**kwargs: Any,
) -> Qdrant:
try:
import qdrant_client # noqa
except ImportError:
raise ImportError(
"Could not import qdrant-client python package. "
"Please install it with `pip install qdrant-client`."
)
from grpc import RpcError
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
# Just do a single quick embedding to get vector size
partial_embeddings = await embedding.aembed_documents(texts[:1])
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
client, async_client = cls._generate_clients(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
try:
# Skip any validation in case of forced collection recreate.
if force_recreate:
raise ValueError
# Get the vector configuration of the existing collection and vector, if it
# was specified. If the old configuration does not match the current one,
# an exception is being thrown.
collection_info = client.get_collection(collection_name=collection_name)
current_vector_config = collection_info.config.params.vectors
if isinstance(current_vector_config, dict) and vector_name is not None:
if vector_name not in current_vector_config:
raise QdrantException(
f"Existing Qdrant collection {collection_name} does not "
f"contain vector named {vector_name}. Did you mean one of the "
f"existing vectors: {', '.join(current_vector_config.keys())}? "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_vector_config = current_vector_config.get(vector_name) # type: ignore[assignment]
elif isinstance(current_vector_config, dict) and vector_name is None:
raise QdrantException(
f"Existing Qdrant collection {collection_name} uses named vectors. "
f"If you want to reuse it, please set `vector_name` to any of the "
f"existing named vectors: "
f"{', '.join(current_vector_config.keys())}." # noqa
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
elif (
not isinstance(current_vector_config, dict) and vector_name is not None
):
raise QdrantException(
f"Existing Qdrant collection {collection_name} doesn't use named "
f"vectors. If you want to reuse it, please set `vector_name` to "
f"`None`. If you want to recreate the collection, set "
f"`force_recreate` parameter to `True`."
)
# Check if the vector configuration has the same dimensionality.
if current_vector_config.size != vector_size: # type: ignore[union-attr]
raise QdrantException(
f"Existing Qdrant collection is configured for vectors with "
f"{current_vector_config.size} " # type: ignore[union-attr]
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
f"If you want to recreate the collection, set `force_recreate` "
f"parameter to `True`."
)
current_distance_func = (
current_vector_config.distance.name.upper() # type: ignore[union-attr]
)
if current_distance_func != distance_func:
raise QdrantException(
f"Existing Qdrant collection is configured for "
f"{current_vector_config.distance} " # type: ignore[union-attr]
f"similarity. Please set `distance_func` parameter to "
f"`{distance_func}` if you want to reuse it. If you want to "
f"recreate the collection, set `force_recreate` parameter to "
f"`True`."
)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[distance_func],
on_disk=on_disk,
)
# If vector name was provided, we're going to use the named vectors feature
# with just a single vector.
if vector_name is not None:
vectors_config = { # type: ignore[assignment]
vector_name: vectors_config,
}
client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
shard_number=shard_number,
replication_factor=replication_factor,
write_consistency_factor=write_consistency_factor,
on_disk_payload=on_disk_payload,
hnsw_config=hnsw_config,
optimizers_config=optimizers_config,
wal_config=wal_config,
quantization_config=quantization_config,
init_from=init_from,
timeout=timeout, # type: ignore[arg-type]
)
qdrant = cls(
client=client,
collection_name=collection_name,
embeddings=embedding,
content_payload_key=content_payload_key,
metadata_payload_key=metadata_payload_key,
distance_strategy=distance_func,
vector_name=vector_name,
async_client=async_client,
)
return qdrant
@staticmethod
def _cosine_relevance_score_fn(distance: float) -> float:
"""将距离归一化到一个范围为[0, 1]的分数。"""
return (distance + 1.0) / 2.0
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""“正确”的相关性函数可能会有所不同,取决于一些因素,包括:
- 向量存储中使用的距离/相似度度量
- 嵌入的规模(OpenAI的是单位规范化的。许多其他嵌入不是!)
- 嵌入的维度
- 等等。
"""
if self.distance_strategy == "COSINE":
return self._cosine_relevance_score_fn
elif self.distance_strategy == "DOT":
return self._max_inner_product_relevance_score_fn
elif self.distance_strategy == "EUCLID":
return self._euclidean_relevance_score_fn
else:
raise ValueError(
"Unknown distance strategy, must be cosine, "
"max_inner_product, or euclidean"
)
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回文档和相关性得分在[0, 1]范围内。
0表示不相似,1表示最相似。
参数:
query: 输入文本
k: 要返回的文档数量。默认为4。
**kwargs: 要传递给相似性搜索的kwargs。应包括:
score_threshold: 可选,介于0到1之间的浮点值,用于过滤检索到的文档集合
返回:
元组列表(doc,相似度得分)
"""
return self.similarity_search_with_score(query, k, **kwargs)
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
) -> List[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
}
)
return payloads
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
collection_name: str,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
metadata = scored_point.payload.get(metadata_payload_key) or {}
metadata["_id"] = scored_point.id
metadata["_collection_name"] = collection_name
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=metadata,
)
def _build_condition(self, key: str, value: Any) -> List[rest.FieldCondition]:
from qdrant_client.http import models as rest
out = []
if isinstance(value, dict):
for _key, value in value.items():
out.extend(self._build_condition(f"{key}.{_key}", value))
elif isinstance(value, list):
for _value in value:
if isinstance(_value, dict):
out.extend(self._build_condition(f"{key}[]", _value))
else:
out.extend(self._build_condition(f"{key}", _value))
else:
out.append(
rest.FieldCondition(
key=f"{self.metadata_payload_key}.{key}",
match=rest.MatchValue(value=value),
)
)
return out
def _qdrant_filter_from_dict(
self, filter: Optional[DictFilter]
) -> Optional[rest.Filter]:
from qdrant_client.http import models as rest
if not filter:
return None
return rest.Filter(
must=[
condition
for key, value in filter.items()
for condition in self._build_condition(key, value)
]
)
def _embed_query(self, query: str) -> List[float]:
"""嵌入查询文本。
用于与`embedding_function`参数提供向后兼容性。
参数:
query: 查询文本。
返回:
代表查询嵌入的浮点数列表。
"""
if self.embeddings is not None:
embedding = self.embeddings.embed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
async def _aembed_query(self, query: str) -> List[float]:
"""异步嵌入查询文本。
用于与`embedding_function`参数提供向后兼容性。
参数:
query: 查询文本。
返回:
代表查询嵌入的浮点数列表。
"""
if self.embeddings is not None:
embedding = await self.embeddings.aembed_query(query)
else:
if self._embeddings_function is not None:
embedding = self._embeddings_function(query)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embedding.tolist() if hasattr(embedding, "tolist") else embedding
def _embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
"""嵌入搜索文本。
用于与`embedding_function`参数提供向后兼容性。
参数:
texts:要嵌入的文本的可迭代对象。
返回:
代表文本嵌入的浮点数列表。
"""
if self.embeddings is not None:
embeddings = self.embeddings.embed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embeddings, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings
async def _aembed_texts(self, texts: Iterable[str]) -> List[List[float]]:
"""嵌入搜索文本。
用于与`embedding_function`参数提供向后兼容性。
参数:
texts:要嵌入的文本的可迭代对象。
返回:
代表文本嵌入的浮点数列表。
"""
if self.embeddings is not None:
embeddings = await self.embeddings.aembed_documents(list(texts))
if hasattr(embeddings, "tolist"):
embeddings = embeddings.tolist()
elif self._embeddings_function is not None:
embeddings = []
for text in texts:
embedding = self._embeddings_function(text)
if hasattr(embeddings, "tolist"):
embedding = embedding.tolist()
embeddings.append(embedding)
else:
raise ValueError("Neither of embeddings or embedding_function is set")
return embeddings
def _generate_rest_batches(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = self._embed_texts(batch_texts)
points = [
rest.PointStruct(
id=point_id,
vector=vector
if self.vector_name is None
else {self.vector_name: vector},
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
)
]
yield batch_ids, points
async def _agenerate_rest_batches(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
) -> AsyncGenerator[Tuple[List[str], List[rest.PointStruct]], None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = await self._aembed_texts(batch_texts)
points = [
rest.PointStruct(
id=point_id,
vector=vector
if self.vector_name is None
else {self.vector_name: vector},
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
),
)
]
yield batch_ids, points
@staticmethod
def _generate_clients(
location: Optional[str] = None,
url: Optional[str] = None,
port: Optional[int] = 6333,
grpc_port: int = 6334,
prefer_grpc: bool = False,
https: Optional[bool] = None,
api_key: Optional[str] = None,
prefix: Optional[str] = None,
timeout: Optional[float] = None,
host: Optional[str] = None,
path: Optional[str] = None,
**kwargs: Any,
) -> Tuple[Any, Any]:
from qdrant_client import AsyncQdrantClient, QdrantClient
sync_client = QdrantClient(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
if location == ":memory:" or path is not None:
# Local Qdrant cannot co-exist with Sync and Async clients
# We fallback to sync operations in this case
async_client = None
else:
async_client = AsyncQdrantClient(
location=location,
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
path=path,
**kwargs,
)
return sync_client, async_client