Source code for langchain_community.vectorstores.sklearn

""" 封装了scikit-learn最近邻实现。

向量存储可以以json、bson或parquet格式持久化。
"""

import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

DEFAULT_K = 4  # Number of Documents to return.
DEFAULT_FETCH_K = 20  # Number of Documents to initially fetch during MMR search.


class BaseSerializer(ABC):
    """Base class for serializing data."""

    def __init__(self, persist_path: str) -> None:
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """The file extension suggested by this serializer (without dot)."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Saves the data to the persist_path."""

    @abstractmethod
    def load(self) -> Any:
        """Loads the data from the persist_path."""


class JsonSerializer(BaseSerializer):
    """Serializes data in json using the json package from the python standard
    library."""

    @classmethod
    def extension(cls) -> str:
        return "json"

    def save(self, data: Any) -> None:
        with open(self.persist_path, "w") as fp:
            json.dump(data, fp)

    def load(self) -> Any:
        with open(self.persist_path, "r") as fp:
            return json.load(fp)
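

# Example (illustrative, not part of the module): a minimal round trip through
# JsonSerializer; the same save/load contract applies to every BaseSerializer
# subclass. The path is hypothetical.
serializer = JsonSerializer(persist_path="/tmp/store.json")
serializer.save({"ids": ["1"], "texts": ["hello"], "embeddings": [[0.1, 0.2]]})
assert serializer.load() == {"ids": ["1"], "texts": ["hello"], "embeddings": [[0.1, 0.2]]}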


class BsonSerializer(BaseSerializer):
    """Serializes data in binary json using the `bson` python package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        return "bson"

    def save(self, data: Any) -> None:
        with open(self.persist_path, "wb") as fp:
            fp.write(self.bson.dumps(data))

    def load(self) -> Any:
        with open(self.persist_path, "rb") as fp:
            return self.bson.loads(fp.read())


class ParquetSerializer(BaseSerializer):
    """Serializes data in `Apache Parquet` format using the `pyarrow` package."""

    def __init__(self, persist_path: str) -> None:
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        return "parquet"

    def save(self, data: Any) -> None:
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)
        if os.path.exists(self.persist_path):
            # Move the existing file aside so a failed write cannot clobber it.
            backup_path = str(self.persist_path) + "-backup"
            os.rename(self.persist_path, backup_path)
            try:
                self.pq.write_table(table, self.persist_path)
            except Exception as exc:
                # Restore the previous file if the write failed.
                os.rename(backup_path, self.persist_path)
                raise exc
            else:
                os.remove(backup_path)
        else:
            self.pq.write_table(table, self.persist_path)

    def load(self) -> Any:
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
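

# Example (illustrative): ParquetSerializer expects a columnar payload, i.e. a
# dict of equal-length lists, because save() builds a pandas DataFrame from it.
# This matches the dict assembled in SKLearnVectorStore.persist() below.
# Requires `pandas` and `pyarrow`; the path is hypothetical.
serializer = ParquetSerializer(persist_path="/tmp/store.parquet")
serializer.save({"ids": ["1", "2"], "texts": ["foo", "bar"]})
print(serializer.load())  # {'ids': ['1', '2'], 'texts': ['foo', 'bar']}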


SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    "json": JsonSerializer,
    "bson": BsonSerializer,
    "parquet": ParquetSerializer,
}


class SKLearnVectorStoreException(RuntimeError):
    """Exception raised by SKLearnVectorStore."""

    pass


class SKLearnVectorStore(VectorStore):
    """Simple in-memory vector store based on the `scikit-learn` library
    `NearestNeighbors`."""

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    @property
    def embeddings(self) -> Embeddings:
        return self._embedding_function

    def persist(self) -> None:
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)
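
    # Example (illustrative): with a persist_path set, persist() writes the
    # collection to disk, and a fresh instance pointed at the same file reloads
    # it in __init__. DeterministicFakeEmbedding and the path are illustrative
    # only; run at module level, not inside the class.
    #
    #     from langchain_core.embeddings import DeterministicFakeEmbedding
    #
    #     emb = DeterministicFakeEmbedding(size=16)
    #     store = SKLearnVectorStore(emb, persist_path="/tmp/store.json")
    #     store.add_texts(["alpha", "beta"])
    #     store.persist()
    #     restored = SKLearnVectorStore(emb, persist_path="/tmp/store.json")
    #     restored.similarity_search_with_score("alpha", k=1)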

    def _load(self) -> None:
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the "
                "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        _texts = list(texts)
        _ids = ids or [str(uuid4()) for _ in _texts]
        self._texts.extend(_texts)
        self._embeddings.extend(self._embedding_function.embed_documents(_texts))
        self._metadatas.extend(metadatas or ([{}] * len(_texts)))
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search k embeddings similar to the query embedding. Returns a list
        of (index, distance) tuples."""
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=k
        )
        return list(zip(neigh_idxs[0], neigh_dists[0]))
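
    # Example (illustrative): the underlying scikit-learn call returns parallel
    # arrays of distances and indices, which are zipped into (index, distance)
    # pairs above. Standalone sketch, runnable at module level:
    #
    #     import numpy as np
    #     from sklearn.neighbors import NearestNeighbors
    #
    #     vectors = np.asarray([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    #     nn = NearestNeighbors(metric="cosine").fit(vectors)
    #     dists, idxs = nn.kneighbors([[1.0, 1.0]], n_neighbors=2)
    #     list(zip(idxs[0], dists[0]))  # e.g. [(2, 0.0), (0, 0.2928...)]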

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [
            (
                Document(
                    page_content=self._texts[idx],
                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
                ),
                dist,
            )
            for idx, dist in indices_dists
        ]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        docs, dists = zip(*docs_dists)
        scores = [1 / math.exp(dist) for dist in dists]
        return list(zip(list(docs), scores))
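
    # Note (added for clarity): the relevance score is 1 / e**distance, i.e.
    # exp(-distance), so a distance of 0 maps to a score of 1.0 and larger
    # distances decay toward 0; for non-negative distances scores lie in (0, 1].
    #
    #     import math
    #     [1 / math.exp(d) for d in (0.0, 0.5, 1.0)]  # [1.0, 0.6065..., 0.3678...]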

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices, _ = zip(*indices_dists)
        result_embeddings = self._embeddings_np[indices,]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        mmr_indices = [indices[i] for i in mmr_selected]
        return [
            Document(
                page_content=self._texts[idx],
                metadata={"id": self._ids[idx], **self._metadatas[idx]},
            )
            for idx in mmr_indices
        ]
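
    # Example (illustrative): MMR search takes a query *vector*, so embed the
    # query yourself first. DeterministicFakeEmbedding is illustrative only;
    # run at module level, not inside the class.
    #
    #     from langchain_core.embeddings import DeterministicFakeEmbedding
    #
    #     emb = DeterministicFakeEmbedding(size=16)
    #     store = SKLearnVectorStore(emb)
    #     store.add_texts(["spring", "summer", "autumn", "winter"])
    #     query_vec = emb.embed_query("season")
    #     # Fetch 4 candidates by similarity, then pick 2 of them for diversity.
    #     docs = store.max_marginal_relevance_search_by_vector(
    #         query_vec, k=2, fetch_k=4
    #     )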

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        vs = SKLearnVectorStore(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs
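

# Example (illustrative, not part of the module): a minimal end-to-end sketch
# using from_texts. DeterministicFakeEmbedding embeds a given text to the same
# vector every time, so the identical query comes back at distance ~0.
from langchain_core.embeddings import DeterministicFakeEmbedding

store = SKLearnVectorStore.from_texts(
    texts=["hello world", "goodbye world"],
    embedding=DeterministicFakeEmbedding(size=16),
)
results = store.similarity_search_with_score("hello world", k=1)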