import logging
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
)
from uuid import uuid4
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import (
DistanceStrategy,
maximal_marginal_relevance,
)
VST = TypeVar("VST", bound="VectorStore")
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from momento import PreviewVectorIndexClient
[docs]class MomentoVectorIndex(VectorStore):
"""`Momento Vector Index`(MVI)向量存储。
Momento Vector Index 是一个无服务器向量索引,可用于存储和搜索向量。要使用它,您应该已安装``momento`` python包。
示例:
.. code-block:: python
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import MomentoVectorIndex
from momento import (
CredentialProvider,
PreviewVectorIndexClient,
VectorIndexConfigurations,
)
vectorstore = MomentoVectorIndex(
embedding=OpenAIEmbeddings(),
client=PreviewVectorIndexClient(
VectorIndexConfigurations.Default.latest(),
credential_provider=CredentialProvider.from_environment_variable(
"MOMENTO_API_KEY"
),
),
index_name="my-index",
)"""
[docs] def __init__(
self,
embedding: Embeddings,
client: "PreviewVectorIndexClient",
index_name: str = "default",
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
text_field: str = "text",
ensure_index_exists: bool = True,
**kwargs: Any,
):
"""初始化由Momento Vector Index支持的Vector Store。
参数:
embedding (Embeddings): 要使用的嵌入函数。
configuration (VectorIndexConfiguration): 用于初始化Vector Index的配置。
credential_provider (CredentialProvider): 用于验证Vector Index的凭据提供程序。
index_name (str, optional): 存储文档的索引名称。默认为"default"。
distance_strategy (DistanceStrategy, optional): 要使用的距离策略。如果选择DistanceStrategy.EUCLIDEAN_DISTANCE,Momento将使用平方欧氏距离。默认为DistanceStrategy.COSINE。
text_field (str, optional): 存储原始文本的元数据字段名称。默认为"text"。
ensure_index_exists (bool, optional): 在向其添加文档之前是否确保索引存在。默认为True。
"""
try:
from momento import PreviewVectorIndexClient
except ImportError:
raise ImportError(
"Could not import momento python package. "
"Please install it with `pip install momento`."
)
self._client: PreviewVectorIndexClient = client
self._embedding = embedding
self.index_name = index_name
self.__validate_distance_strategy(distance_strategy)
self.distance_strategy = distance_strategy
self.text_field = text_field
self._ensure_index_exists = ensure_index_exists
@staticmethod
def __validate_distance_strategy(distance_strategy: DistanceStrategy) -> None:
if distance_strategy not in [
DistanceStrategy.COSINE,
DistanceStrategy.MAX_INNER_PRODUCT,
DistanceStrategy.MAX_INNER_PRODUCT,
]:
raise ValueError(f"Distance strategy {distance_strategy} not implemented.")
@property
def embeddings(self) -> Embeddings:
return self._embedding
def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
"""如果索引不存在,则创建索引。"""
from momento.requests.vector_index import SimilarityMetric
from momento.responses.vector_index import CreateIndex
similarity_metric = None
if self.distance_strategy == DistanceStrategy.COSINE:
similarity_metric = SimilarityMetric.COSINE_SIMILARITY
elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
similarity_metric = SimilarityMetric.INNER_PRODUCT
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
else:
logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
raise ValueError(
f"Distance strategy {self.distance_strategy} not implemented."
)
response = self._client.create_index(
self.index_name, num_dimensions, similarity_metric
)
if isinstance(response, CreateIndex.Success):
return True
elif isinstance(response, CreateIndex.IndexAlreadyExists):
return False
elif isinstance(response, CreateIndex.Error):
logger.error(f"Error creating index: {response.inner_exception}")
raise response.inner_exception
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""运行更多的文本通过嵌入并添加到向量存储中。
参数:
texts (Iterable[str]): 要添加到向量存储中的字符串的可迭代对象。
metadatas (Optional[List[dict]]): 与文本相关联的元数据的可选列表。
kwargs (Any): 其他可选参数。具体包括:
- ids (List[str], optional): 用于文本的id列表。
默认为None,此时将生成uuid。
返回:
List[str]: 将文本添加到向量存储中后的id列表。
"""
from momento.requests.vector_index import Item
from momento.responses.vector_index import UpsertItemBatch
texts = list(texts)
if len(texts) == 0:
return []
if metadatas is not None:
for metadata, text in zip(metadatas, texts):
metadata[self.text_field] = text
else:
metadatas = [{self.text_field: text} for text in texts]
try:
embeddings = self._embedding.embed_documents(texts)
except NotImplementedError:
embeddings = [self._embedding.embed_query(x) for x in texts]
# 如果索引不存在,则创建索引。
# We assume that if it does exist, then it was created with the desired number
# of dimensions and similarity metric.
if self._ensure_index_exists:
self._create_index_if_not_exists(len(embeddings[0]))
if "ids" in kwargs:
ids = kwargs["ids"]
if len(ids) != len(embeddings):
raise ValueError("Number of ids must match number of texts")
else:
ids = [str(uuid4()) for _ in range(len(embeddings))]
batch_size = 128
for i in range(0, len(embeddings), batch_size):
start = i
end = min(i + batch_size, len(embeddings))
items = [
Item(id=id, vector=vector, metadata=metadata)
for id, vector, metadata in zip(
ids[start:end],
embeddings[start:end],
metadatas[start:end],
)
]
response = self._client.upsert_item_batch(self.index_name, items)
if isinstance(response, UpsertItemBatch.Success):
pass
elif isinstance(response, UpsertItemBatch.Error):
raise response.inner_exception
else:
raise Exception(f"Unexpected response: {response}")
return ids
[docs] def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""根据向量ID删除。
参数:
ids(List[str]):要删除的ID列表。
kwargs(Any):其他可选参数(未使用)
返回:
Optional[bool]:如果删除成功则为True,否则为False,如果未实现则为None。
"""
from momento.responses.vector_index import DeleteItemBatch
if ids is None:
return True
response = self._client.delete_item_batch(self.index_name, ids)
return isinstance(response, DeleteItemBatch.Success)
[docs] def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""搜索与查询字符串相似的文档。
参数:
query(str):要搜索的查询字符串。
k(int,可选):要返回的结果数量。默认为4。
返回:
List[Document]:与查询相似的文档列表。
"""
res = self.similarity_search_with_score(query=query, k=k, **kwargs)
return [doc for doc, _ in res]
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""搜索与查询字符串相似的文档。
参数:
query (str): 要搜索的查询字符串。
k (int, optional): 要返回的结果数量。默认为4。
kwargs (Any): 向量存储特定的搜索参数。以下参数将被转发给Momento向量索引:
- top_k (int, optional): 要返回的结果数量。
返回:
List[Tuple[Document, float]]: 一个元组列表,形式为(Document, score)。
"""
embedding = self._embedding.embed_query(query)
results = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)
return results
[docs] def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""搜索与查询向量相似的文档。
参数:
embedding(List[float]):要搜索的查询向量。
k(int,可选):要返回的结果数量。默认为4。
kwargs(Any):向量存储特定的搜索参数。以下参数将被转发到Momento向量索引:
- top_k(int,可选):要返回的结果数量。
返回:
List[Tuple[Document, float]]:形式为(Document,score)的元组列表。
"""
from momento.requests.vector_index import ALL_METADATA
from momento.responses.vector_index import Search
if "top_k" in kwargs:
k = kwargs["k"]
filter_expression = kwargs.get("filter_expression", None)
response = self._client.search(
self.index_name,
embedding,
top_k=k,
metadata_fields=ALL_METADATA,
filter_expression=filter_expression,
)
if not isinstance(response, Search.Success):
return []
results = []
for hit in response.hits:
text = cast(str, hit.metadata.pop(self.text_field))
doc = Document(page_content=text, metadata=hit.metadata)
pair = (doc, hit.score)
results.append(pair)
return results
[docs] def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""搜索与查询向量相似的文档。
参数:
embedding (List[float]): 要搜索的查询向量。
k (int, optional): 要返回的结果数量。默认为4。
返回:
List[Document]: 与查询相似的文档列表。
"""
results = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, **kwargs
)
return [doc for doc, _ in results]
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
fetch_k:要获取并传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,
0对应最大多样性,1对应最小多样性。
默认为0.5。
返回:
由最大边际相关性选择的文档列表。
"""
from momento.requests.vector_index import ALL_METADATA
from momento.responses.vector_index import SearchAndFetchVectors
filter_expression = kwargs.get("filter_expression", None)
response = self._client.search_and_fetch_vectors(
self.index_name,
embedding,
top_k=fetch_k,
metadata_fields=ALL_METADATA,
filter_expression=filter_expression,
)
if isinstance(response, SearchAndFetchVectors.Success):
pass
elif isinstance(response, SearchAndFetchVectors.Error):
logger.error(f"Error searching and fetching vectors: {response}")
return []
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")
mmr_selected = maximal_marginal_relevance(
query_embedding=np.array([embedding], dtype=np.float32),
embedding_list=[hit.vector for hit in response.hits],
lambda_mult=lambda_mult,
k=k,
)
selected = [response.hits[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501
for metadata in selected
]
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找相似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,其中0对应最大多样性,1对应最小多样性。默认为0.5。
返回:
由最大边际相关性选择的文档列表。
"""
embedding = self._embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult, **kwargs
)
[docs] @classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
"""返回从文本和嵌入初始化的向量存储。
参数:
cls(Type[VST]):用于初始化向量存储的向量存储类。
texts(List[str]):用于初始化向量存储的文本。
embedding(Embeddings):要使用的嵌入函数。
metadatas(Optional[List[dict],可选):与文本相关联的元数据。默认为None。
kwargs(Any):向量存储特定参数。以下参数将被转发到向量存储构造函数并且是必需的:
- index_name(str,可选):存储文档的索引名称。默认为"default"。
- text_field(str,可选):存储原始文本的元数据字段名称。默认为"text"。
- distance_strategy(DistanceStrategy,可选):要使用的距离策略。默认为DistanceStrategy.COSINE。如果选择DistanceStrategy.EUCLIDEAN_DISTANCE,Momento将使用平方欧氏距离。
- ensure_index_exists(bool,可选):在向其添加文档之前是否确保索引存在。默认为True。
此外,您可以传入客户端或API密钥
- client(PreviewVectorIndexClient):要使用的Momento向量索引客户端。
- api_key(Optional[str]):用于初始化向量索引的配置。默认为None。如果为None,则配置将从环境变量`MOMENTO_API_KEY`初始化。
返回:
VST:从文本和嵌入初始化的Momento向量索引向量存储。
"""
from momento import (
CredentialProvider,
PreviewVectorIndexClient,
VectorIndexConfigurations,
)
if "client" in kwargs:
client = kwargs.pop("client")
else:
supplied_api_key = kwargs.pop("api_key", None)
api_key = supplied_api_key or get_from_env("api_key", "MOMENTO_API_KEY")
client = PreviewVectorIndexClient(
configuration=VectorIndexConfigurations.Default.latest(),
credential_provider=CredentialProvider.from_string(api_key),
)
vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore
vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs)
return vector_db