Source code for langchain_community.vectorstores.momento_vector_index

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    cast,
)
from uuid import uuid4

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
)

VST = TypeVar("VST", bound="VectorStore")

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from momento import PreviewVectorIndexClient


[docs]class MomentoVectorIndex(VectorStore): """`Momento Vector Index`(MVI)向量存储。 Momento Vector Index 是一个无服务器向量索引,可用于存储和搜索向量。要使用它,您应该已安装``momento`` python包。 示例: .. code-block:: python from langchain_community.embeddings import OpenAIEmbeddings from langchain_community.vectorstores import MomentoVectorIndex from momento import ( CredentialProvider, PreviewVectorIndexClient, VectorIndexConfigurations, ) vectorstore = MomentoVectorIndex( embedding=OpenAIEmbeddings(), client=PreviewVectorIndexClient( VectorIndexConfigurations.Default.latest(), credential_provider=CredentialProvider.from_environment_variable( "MOMENTO_API_KEY" ), ), index_name="my-index", )"""
[docs] def __init__( self, embedding: Embeddings, client: "PreviewVectorIndexClient", index_name: str = "default", distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, text_field: str = "text", ensure_index_exists: bool = True, **kwargs: Any, ): """初始化由Momento Vector Index支持的Vector Store。 参数: embedding (Embeddings): 要使用的嵌入函数。 configuration (VectorIndexConfiguration): 用于初始化Vector Index的配置。 credential_provider (CredentialProvider): 用于验证Vector Index的凭据提供程序。 index_name (str, optional): 存储文档的索引名称。默认为"default"。 distance_strategy (DistanceStrategy, optional): 要使用的距离策略。如果选择DistanceStrategy.EUCLIDEAN_DISTANCE,Momento将使用平方欧氏距离。默认为DistanceStrategy.COSINE。 text_field (str, optional): 存储原始文本的元数据字段名称。默认为"text"。 ensure_index_exists (bool, optional): 在向其添加文档之前是否确保索引存在。默认为True。 """ try: from momento import PreviewVectorIndexClient except ImportError: raise ImportError( "Could not import momento python package. " "Please install it with `pip install momento`." ) self._client: PreviewVectorIndexClient = client self._embedding = embedding self.index_name = index_name self.__validate_distance_strategy(distance_strategy) self.distance_strategy = distance_strategy self.text_field = text_field self._ensure_index_exists = ensure_index_exists
@staticmethod def __validate_distance_strategy(distance_strategy: DistanceStrategy) -> None: if distance_strategy not in [ DistanceStrategy.COSINE, DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.MAX_INNER_PRODUCT, ]: raise ValueError(f"Distance strategy {distance_strategy} not implemented.") @property def embeddings(self) -> Embeddings: return self._embedding def _create_index_if_not_exists(self, num_dimensions: int) -> bool: """如果索引不存在,则创建索引。""" from momento.requests.vector_index import SimilarityMetric from momento.responses.vector_index import CreateIndex similarity_metric = None if self.distance_strategy == DistanceStrategy.COSINE: similarity_metric = SimilarityMetric.COSINE_SIMILARITY elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: similarity_metric = SimilarityMetric.INNER_PRODUCT elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY else: logger.error(f"Distance strategy {self.distance_strategy} not implemented.") raise ValueError( f"Distance strategy {self.distance_strategy} not implemented." ) response = self._client.create_index( self.index_name, num_dimensions, similarity_metric ) if isinstance(response, CreateIndex.Success): return True elif isinstance(response, CreateIndex.IndexAlreadyExists): return False elif isinstance(response, CreateIndex.Error): logger.error(f"Error creating index: {response.inner_exception}") raise response.inner_exception else: logger.error(f"Unexpected response: {response}") raise Exception(f"Unexpected response: {response}")
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> List[str]: """运行更多的文本通过嵌入并添加到向量存储中。 参数: texts (Iterable[str]): 要添加到向量存储中的字符串的可迭代对象。 metadatas (Optional[List[dict]]): 与文本相关联的元数据的可选列表。 kwargs (Any): 其他可选参数。具体包括: - ids (List[str], optional): 用于文本的id列表。 默认为None,此时将生成uuid。 返回: List[str]: 将文本添加到向量存储中后的id列表。 """ from momento.requests.vector_index import Item from momento.responses.vector_index import UpsertItemBatch texts = list(texts) if len(texts) == 0: return [] if metadatas is not None: for metadata, text in zip(metadatas, texts): metadata[self.text_field] = text else: metadatas = [{self.text_field: text} for text in texts] try: embeddings = self._embedding.embed_documents(texts) except NotImplementedError: embeddings = [self._embedding.embed_query(x) for x in texts] # 如果索引不存在,则创建索引。 # We assume that if it does exist, then it was created with the desired number # of dimensions and similarity metric. if self._ensure_index_exists: self._create_index_if_not_exists(len(embeddings[0])) if "ids" in kwargs: ids = kwargs["ids"] if len(ids) != len(embeddings): raise ValueError("Number of ids must match number of texts") else: ids = [str(uuid4()) for _ in range(len(embeddings))] batch_size = 128 for i in range(0, len(embeddings), batch_size): start = i end = min(i + batch_size, len(embeddings)) items = [ Item(id=id, vector=vector, metadata=metadata) for id, vector, metadata in zip( ids[start:end], embeddings[start:end], metadatas[start:end], ) ] response = self._client.upsert_item_batch(self.index_name, items) if isinstance(response, UpsertItemBatch.Success): pass elif isinstance(response, UpsertItemBatch.Error): raise response.inner_exception else: raise Exception(f"Unexpected response: {response}") return ids
[docs] def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: """根据向量ID删除。 参数: ids(List[str]):要删除的ID列表。 kwargs(Any):其他可选参数(未使用) 返回: Optional[bool]:如果删除成功则为True,否则为False,如果未实现则为None。 """ from momento.responses.vector_index import DeleteItemBatch if ids is None: return True response = self._client.delete_item_batch(self.index_name, ids) return isinstance(response, DeleteItemBatch.Success)
[docs] def similarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any, ) -> List[Tuple[Document, float]]: """搜索与查询字符串相似的文档。 参数: query (str): 要搜索的查询字符串。 k (int, optional): 要返回的结果数量。默认为4。 kwargs (Any): 向量存储特定的搜索参数。以下参数将被转发给Momento向量索引: - top_k (int, optional): 要返回的结果数量。 返回: List[Tuple[Document, float]]: 一个元组列表,形式为(Document, score)。 """ embedding = self._embedding.embed_query(query) results = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, **kwargs ) return results
[docs] def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, **kwargs: Any, ) -> List[Tuple[Document, float]]: """搜索与查询向量相似的文档。 参数: embedding(List[float]):要搜索的查询向量。 k(int,可选):要返回的结果数量。默认为4。 kwargs(Any):向量存储特定的搜索参数。以下参数将被转发到Momento向量索引: - top_k(int,可选):要返回的结果数量。 返回: List[Tuple[Document, float]]:形式为(Document,score)的元组列表。 """ from momento.requests.vector_index import ALL_METADATA from momento.responses.vector_index import Search if "top_k" in kwargs: k = kwargs["k"] filter_expression = kwargs.get("filter_expression", None) response = self._client.search( self.index_name, embedding, top_k=k, metadata_fields=ALL_METADATA, filter_expression=filter_expression, ) if not isinstance(response, Search.Success): return [] results = [] for hit in response.hits: text = cast(str, hit.metadata.pop(self.text_field)) doc = Document(page_content=text, metadata=hit.metadata) pair = (doc, hit.score) results.append(pair) return results
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, **kwargs: Any ) -> List[Document]: """搜索与查询向量相似的文档。 参数: embedding (List[float]): 要搜索的查询向量。 k (int, optional): 要返回的结果数量。默认为4。 返回: List[Document]: 与查询相似的文档列表。 """ results = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, **kwargs ) return [doc for doc, _ in results]
[docs] def max_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, ) -> List[Document]: """返回使用最大边际相关性选择的文档。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: embedding:要查找相似文档的嵌入。 k:要返回的文档数量。默认为4。 fetch_k:要获取并传递给MMR算法的文档数量。 lambda_mult:0到1之间的数字,确定结果之间多样性的程度, 0对应最大多样性,1对应最小多样性。 默认为0.5。 返回: 由最大边际相关性选择的文档列表。 """ from momento.requests.vector_index import ALL_METADATA from momento.responses.vector_index import SearchAndFetchVectors filter_expression = kwargs.get("filter_expression", None) response = self._client.search_and_fetch_vectors( self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA, filter_expression=filter_expression, ) if isinstance(response, SearchAndFetchVectors.Success): pass elif isinstance(response, SearchAndFetchVectors.Error): logger.error(f"Error searching and fetching vectors: {response}") return [] else: logger.error(f"Unexpected response: {response}") raise Exception(f"Unexpected response: {response}") mmr_selected = maximal_marginal_relevance( query_embedding=np.array([embedding], dtype=np.float32), embedding_list=[hit.vector for hit in response.hits], lambda_mult=lambda_mult, k=k, ) selected = [response.hits[i].metadata for i in mmr_selected] return [ Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501 for metadata in selected ]
[docs] @classmethod def from_texts( cls: Type[VST], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, **kwargs: Any, ) -> VST: """返回从文本和嵌入初始化的向量存储。 参数: cls(Type[VST]):用于初始化向量存储的向量存储类。 texts(List[str]):用于初始化向量存储的文本。 embedding(Embeddings):要使用的嵌入函数。 metadatas(Optional[List[dict],可选):与文本相关联的元数据。默认为None。 kwargs(Any):向量存储特定参数。以下参数将被转发到向量存储构造函数并且是必需的: - index_name(str,可选):存储文档的索引名称。默认为"default"。 - text_field(str,可选):存储原始文本的元数据字段名称。默认为"text"。 - distance_strategy(DistanceStrategy,可选):要使用的距离策略。默认为DistanceStrategy.COSINE。如果选择DistanceStrategy.EUCLIDEAN_DISTANCE,Momento将使用平方欧氏距离。 - ensure_index_exists(bool,可选):在向其添加文档之前是否确保索引存在。默认为True。 此外,您可以传入客户端或API密钥 - client(PreviewVectorIndexClient):要使用的Momento向量索引客户端。 - api_key(Optional[str]):用于初始化向量索引的配置。默认为None。如果为None,则配置将从环境变量`MOMENTO_API_KEY`初始化。 返回: VST:从文本和嵌入初始化的Momento向量索引向量存储。 """ from momento import ( CredentialProvider, PreviewVectorIndexClient, VectorIndexConfigurations, ) if "client" in kwargs: client = kwargs.pop("client") else: supplied_api_key = kwargs.pop("api_key", None) api_key = supplied_api_key or get_from_env("api_key", "MOMENTO_API_KEY") client = PreviewVectorIndexClient( configuration=VectorIndexConfigurations.Default.latest(), credential_provider=CredentialProvider.from_string(api_key), ) vector_db = cls(embedding=embedding, client=client, **kwargs) # type: ignore vector_db.add_texts(texts=texts, metadatas=metadatas, **kwargs) return vector_db