"""Wrapper around scikit-learn NearestNeighbors implementation.

The vector store can be persisted in json, bson or parquet format.
"""

import json
import math
import os
from abc import ABC, abstractmethod
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
)
from uuid import uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

# Default number of Documents to return from a search.
DEFAULT_K = 4
# Default number of candidate Documents to fetch initially during MMR search.
DEFAULT_FETCH_K = 20
class BaseSerializer(ABC):
    """Abstract base class for serializing data."""
def persist(self) -> None:
    """Write the store's ids, texts, metadatas and embeddings via its serializer.

    The four collections are handed to ``self._serializer.save`` as a single
    dict with the keys ``ids``, ``texts``, ``metadatas`` and ``embeddings``
    (in that order).

    Raises:
        SKLearnVectorStoreException: If no serializer is configured, i.e. the
            store was created without a ``persist_path``.
    """
    serializer = self._serializer
    if serializer is None:
        raise SKLearnVectorStoreException(
            "You must specify a persist_path on creation to persist the "
            "collection."
        )
    payload = dict(
        ids=self._ids,
        texts=self._texts,
        metadatas=self._metadatas,
        embeddings=self._embeddings,
    )
    serializer.save(payload)
def _load(self) -> None:
    """Restore the store's contents from its serializer and refit the index.

    Reads ``embeddings``, ``texts``, ``metadatas`` and ``ids`` (in that order)
    from the dict returned by ``self._serializer.load()``. Each value is stored
    on the matching underscore-prefixed attribute, and then
    ``self._update_neighbors()`` is called.

    Raises:
        SKLearnVectorStoreException: If no serializer is configured, i.e. the
            store was created without a ``persist_path``.
    """
    serializer = self._serializer
    if serializer is None:
        raise SKLearnVectorStoreException(
            "You must specify a persist_path on creation to load the "
            "collection."
        )
    loaded = serializer.load()
    # Assign one field at a time, in this fixed order, so a missing key stops
    # the load at that field.
    for field in ("embeddings", "texts", "metadatas", "ids"):
        setattr(self, f"_{field}", loaded[field])
    self._update_neighbors()
def _update_neighbors(self) -> None:
    """Refit the nearest-neighbor index on all stored embeddings.

    Converts ``self._embeddings`` with ``self._np.asarray`` into
    ``self._embeddings_np``, fits ``self._neighbors`` on it and marks the
    index as fitted.

    Raises:
        SKLearnVectorStoreException: If no embeddings have been added.
    """
    if len(self._embeddings) == 0:
        raise SKLearnVectorStoreException("No data was added to SKLearnVectorStore.")
    # The array form is kept so later searches can gather rows by position.
    self._embeddings_np = self._np.asarray(self._embeddings)
    self._neighbors.fit(self._embeddings_np)
    self._neighbors_fitted = True


def _similarity_index_search_with_score(
    self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
) -> List[Tuple[int, float]]:
    """Search for the embeddings most similar to ``query_embedding``.

    Args:
        query_embedding: Query vector to search around.
        k: Maximum number of neighbors to return. Values larger than the
            number of stored embeddings are capped at that number.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A list of ``(index, distance)`` tuples, nearest first, where ``index``
        is the row position in the stored embeddings.

    Raises:
        SKLearnVectorStoreException: If the index has not been fitted yet.
    """
    if not self._neighbors_fitted:
        raise SKLearnVectorStoreException("No data was added to SKLearnVectorStore.")
    # scikit-learn's ``kneighbors`` raises ValueError when ``n_neighbors``
    # exceeds the number of fitted samples. Without this cap, small stores
    # would fail, e.g. MMR search with the default fetch_k of 20.
    k = min(k, len(self._embeddings_np))
    neigh_dists, neigh_idxs = self._neighbors.kneighbors(
        [query_embedding], n_neighbors=k
    )
    return list(zip(neigh_idxs[0], neigh_dists[0]))
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = DEFAULT_K,
    fetch_k: int = DEFAULT_FETCH_K,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Select documents near ``embedding`` using maximal marginal relevance.

    MMR trades off similarity to the query against diversity among the
    selected documents. First the ``fetch_k`` nearest neighbors are retrieved
    as candidates, then ``k`` of them are picked by MMR.

    Args:
        embedding: Query embedding to look up similar documents for.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of nearest neighbors retrieved as MMR candidates.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that sets the degree of
            diversity; 0 means maximum diversity and 1 minimum diversity.
            Defaults to 0.5.
        **kwargs: Forwarded to the underlying neighbor search.

    Returns:
        The MMR-selected Documents. Each one's metadata is
        ``{"id": <stored id>, **<stored metadata>}``.
    """
    candidates = self._similarity_index_search_with_score(
        embedding, k=fetch_k, **kwargs
    )
    # Separate the (row index, distance) pairs; only the row indices matter.
    candidate_rows, _distances = zip(*candidates)
    # Gather the stored vectors of the candidates, one row per candidate.
    candidate_vectors = self._embeddings_np[list(candidate_rows)]
    query_vector = self._np.array(embedding, dtype=self._np.float32)
    picked = maximal_marginal_relevance(
        query_vector,
        candidate_vectors,
        k=k,
        lambda_mult=lambda_mult,
    )
    # MMR yields positions within the candidate list; map them to store rows.
    rows = [candidate_rows[position] for position in picked]
    return [
        Document(
            page_content=self._texts[row],
            metadata={"id": self._ids[row], **self._metadatas[row]},
        )
        for row in rows
    ]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = DEFAULT_K,
    fetch_k: int = DEFAULT_FETCH_K,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. The query text is embedded with the store's
    embedding function and the search is delegated to
    ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        **kwargs: Forwarded to ``max_marginal_relevance_search_by_vector``.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If the store was created without an embedding function.
    """
    if self._embedding_function is None:
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )
    embedding = self._embedding_function.embed_query(query)
    # ``lambda_mult`` must be passed under its exact name. A misspelled
    # keyword would be absorbed by ``**kwargs`` downstream and silently
    # ignored, leaving the 0.5 default in effect.
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k=k,
        fetch_k=fetch_k,
        lambda_mult=lambda_mult,
        **kwargs,
    )