""" 封装了scikit-learn最近邻实现。
向量存储可以以json、bson或parquet格式持久化。
"""
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Type
from uuid import uuid4
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
# Default for the ``k`` parameter of the search methods in this module.
DEFAULT_K = 4  # Number of Documents to return.
# Default for the ``fetch_k`` parameter of the MMR search methods.
DEFAULT_FETCH_K = 20  # Number of Documents to initially fetch during MMR search.
class BaseSerializer(ABC):
    """Abstract base class for serializers that persist data to a file.

    Concrete subclasses must implement :meth:`extension`, :meth:`save` and
    :meth:`load`.
    """

    def __init__(self, persist_path: str) -> None:
        """Store the location used for saving and loading.

        Args:
            persist_path: Path of the file that :meth:`save` writes to and
                :meth:`load` reads from.
        """
        self.persist_path = persist_path

    @classmethod
    @abstractmethod
    def extension(cls) -> str:
        """Return the suggested file extension for this serializer (no dot)."""

    @abstractmethod
    def save(self, data: Any) -> None:
        """Save ``data`` to ``persist_path``."""

    @abstractmethod
    def load(self) -> Any:
        """Load and return the data stored at ``persist_path``."""
class JsonSerializer(BaseSerializer):
    """Serialize data as JSON using the standard-library ``json`` package."""

    @classmethod
    def extension(cls) -> str:
        """Return the suggested file extension, ``"json"``."""
        return "json"

    def save(self, data: Any) -> None:
        """Write ``data`` as JSON to ``persist_path``.

        Args:
            data: Any object accepted by ``json.dump``.
        """
        # Pin the encoding so the file does not depend on the platform's
        # default locale encoding.
        with open(self.persist_path, "w", encoding="utf-8") as fp:
            json.dump(data, fp)

    def load(self) -> Any:
        """Read ``persist_path`` and return the decoded JSON value."""
        with open(self.persist_path, "r", encoding="utf-8") as fp:
            return json.load(fp)
class BsonSerializer(BaseSerializer):
    """Serialize data as binary JSON using the ``bson`` Python package."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import the ``bson`` package.

        Args:
            persist_path: Path of the file used by :meth:`save` and
                :meth:`load`.
        """
        super().__init__(persist_path)
        # Imported via guard_import and kept on the instance for save/load.
        self.bson = guard_import("bson")

    @classmethod
    def extension(cls) -> str:
        """Return the suggested file extension, ``"bson"``."""
        return "bson"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as the bytes from ``bson.dumps``."""
        with open(self.persist_path, "wb") as fp:
            fp.write(self.bson.dumps(data))

    def load(self) -> Any:
        """Read ``persist_path`` and decode its bytes with ``bson.loads``."""
        with open(self.persist_path, "rb") as fp:
            return self.bson.loads(fp.read())
class ParquetSerializer(BaseSerializer):
    """Serialize data in Apache Parquet format using ``pyarrow``."""

    def __init__(self, persist_path: str) -> None:
        """Store the path and import ``pandas`` and ``pyarrow``.

        Args:
            persist_path: Path of the file used by :meth:`save` and
                :meth:`load`.
        """
        super().__init__(persist_path)
        self.pd = guard_import("pandas")
        self.pa = guard_import("pyarrow")
        self.pq = guard_import("pyarrow.parquet")

    @classmethod
    def extension(cls) -> str:
        """Return the suggested file extension, ``"parquet"``."""
        return "parquet"

    def save(self, data: Any) -> None:
        """Write ``data`` to ``persist_path`` as a Parquet table.

        If a file already exists at ``persist_path`` it is first moved to
        ``<persist_path>-backup``. The backup is restored when writing fails
        and deleted when writing succeeds.

        Args:
            data: Anything accepted by the ``pandas.DataFrame`` constructor.
        """
        df = self.pd.DataFrame(data)
        table = self.pa.Table.from_pandas(df)

        if not os.path.exists(self.persist_path):
            self.pq.write_table(table, self.persist_path)
            return

        backup_path = str(self.persist_path) + "-backup"
        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename fails on Windows if the destination exists (a
        # stale backup, or a partially written target after a failed write).
        os.replace(self.persist_path, backup_path)
        try:
            self.pq.write_table(table, self.persist_path)
        except Exception:
            # Put the previous file back before propagating the error.
            os.replace(backup_path, self.persist_path)
            raise
        else:
            os.remove(backup_path)

    def load(self) -> Any:
        """Read the Parquet file and return a dict of column name to list."""
        table = self.pq.read_table(self.persist_path)
        df = table.to_pandas()
        return {col: series.tolist() for col, series in df.items()}
# Maps the ``serializer`` name accepted by ``SKLearnVectorStore.__init__`` to
# the serializer class that is instantiated for it.
SERIALIZER_MAP: Dict[str, Type[BaseSerializer]] = {
    "json": JsonSerializer,
    "bson": BsonSerializer,
    "parquet": ParquetSerializer,
}
class SKLearnVectorStoreException(RuntimeError):
    """Exception raised by SKLearnVectorStore."""
class SKLearnVectorStore(VectorStore):
    """Simple in-memory vector store built on scikit-learn ``NearestNeighbors``.

    Texts, metadatas, ids and embeddings are kept in parallel Python lists.
    After every change the ``NearestNeighbors`` index is refitted on all
    embeddings. When a ``persist_path`` is given, the lists can be saved with
    :meth:`persist` and are loaded on construction if the file exists.
    """

    def __init__(
        self,
        embedding: Embeddings,
        *,
        persist_path: Optional[str] = None,
        serializer: Literal["json", "bson", "parquet"] = "json",
        metric: str = "cosine",
        **kwargs: Any,
    ) -> None:
        """Create the store and load persisted data if available.

        Args:
            embedding: Embedding function used for documents and queries.
            persist_path: Optional file path for saving/loading the store.
            serializer: Key into ``SERIALIZER_MAP`` selecting the file format.
            metric: Distance metric passed to ``NearestNeighbors``.
            **kwargs: Extra keyword arguments passed to ``NearestNeighbors``.
        """
        np = guard_import("numpy")
        sklearn_neighbors = guard_import("sklearn.neighbors", pip_name="scikit-learn")

        # non-persistent properties
        self._np = np
        self._neighbors = sklearn_neighbors.NearestNeighbors(metric=metric, **kwargs)
        self._neighbors_fitted = False
        self._embedding_function = embedding
        self._persist_path = persist_path
        self._serializer: Optional[BaseSerializer] = None
        if self._persist_path is not None:
            serializer_cls = SERIALIZER_MAP[serializer]
            self._serializer = serializer_cls(persist_path=self._persist_path)

        # data properties (parallel lists, one entry per stored document)
        self._embeddings: List[List[float]] = []
        self._texts: List[str] = []
        self._metadatas: List[dict] = []
        self._ids: List[str] = []

        # cache properties (array form of ``_embeddings`` the index is fit on)
        self._embeddings_np: Any = np.asarray([])

        if self._persist_path is not None and os.path.isfile(self._persist_path):
            self._load()

    @property
    def embeddings(self) -> Embeddings:
        """Return the embedding function given at construction."""
        return self._embedding_function

    def persist(self) -> None:
        """Save ids, texts, metadatas and embeddings with the serializer.

        Raises:
            SKLearnVectorStoreException: If no ``persist_path`` was given on
                creation.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to persist the "
                "collection."
            )
        data = {
            "ids": self._ids,
            "texts": self._texts,
            "metadatas": self._metadatas,
            "embeddings": self._embeddings,
        }
        self._serializer.save(data)

    def _load(self) -> None:
        """Load the stored lists from ``persist_path`` and refit the index.

        Raises:
            SKLearnVectorStoreException: If no ``persist_path`` was given on
                creation.
        """
        if self._serializer is None:
            raise SKLearnVectorStoreException(
                "You must specify a persist_path on creation to load the "
                "collection."
            )
        data = self._serializer.load()
        self._embeddings = data["embeddings"]
        self._texts = data["texts"]
        self._metadatas = data["metadatas"]
        self._ids = data["ids"]
        # An empty store may have been persisted; fitting on no data would
        # raise and make the constructor fail, so leave the index unfitted.
        if len(self._embeddings) > 0:
            self._update_neighbors()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Embed ``texts``, append them to the store and refit the index.

        Args:
            texts: Texts to add.
            metadatas: Optional metadata dicts, one per text. Defaults to an
                empty dict for each text.
            ids: Optional ids, one per text. Defaults to random UUID4 strings.
            **kwargs: Unused.

        Returns:
            The ids of the added texts.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length
                different from the number of texts.
            SKLearnVectorStoreException: If the store is still empty after
                the call.
        """
        _texts = list(texts)
        _ids = ids or [str(uuid4()) for _ in _texts]
        # One fresh dict per text: ``[{}] * n`` would alias a single dict.
        _metadatas = metadatas or [{} for _ in _texts]
        if len(_ids) != len(_texts):
            raise ValueError(
                f"Got {len(_ids)} ids for {len(_texts)} texts; "
                "lengths must match."
            )
        if len(_metadatas) != len(_texts):
            raise ValueError(
                f"Got {len(_metadatas)} metadatas for {len(_texts)} texts; "
                "lengths must match."
            )
        # Embed before touching state so a failing embedder cannot leave the
        # parallel lists with different lengths.
        new_embeddings = self._embedding_function.embed_documents(_texts)
        self._texts.extend(_texts)
        self._embeddings.extend(new_embeddings)
        self._metadatas.extend(_metadatas)
        self._ids.extend(_ids)
        self._update_neighbors()
        return _ids

    def _update_neighbors(self) -> None:
        """Rebuild the embedding array and refit the neighbors index.

        Raises:
            SKLearnVectorStoreException: If the store holds no embeddings.
        """
        if len(self._embeddings) == 0:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        self._embeddings_np = self._np.asarray(self._embeddings)
        self._neighbors.fit(self._embeddings_np)
        self._neighbors_fitted = True

    def _similarity_index_search_with_score(
        self, query_embedding: List[float], *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[int, float]]:
        """Search the ``k`` stored embeddings nearest to ``query_embedding``.

        Args:
            query_embedding: Embedding to search with.
            k: Maximum number of neighbors to return. Capped at the number of
                stored embeddings.
            **kwargs: Unused.

        Returns:
            A list of ``(index, distance)`` tuples.

        Raises:
            SKLearnVectorStoreException: If the index has not been fitted.
        """
        if not self._neighbors_fitted:
            raise SKLearnVectorStoreException(
                "No data was added to SKLearnVectorStore."
            )
        # kneighbors rejects n_neighbors larger than the number of fitted
        # samples, which made small stores fail (e.g. MMR's fetch_k=20).
        n_neighbors = min(k, self._embeddings_np.shape[0])
        neigh_dists, neigh_idxs = self._neighbors.kneighbors(
            [query_embedding], n_neighbors=n_neighbors
        )
        return [
            (int(idx), float(dist))
            for idx, dist in zip(neigh_idxs[0], neigh_dists[0])
        ]

    def similarity_search_with_score(
        self, query: str, *, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the documents nearest to ``query`` with their distances.

        Args:
            query: Text to embed and search with.
            k: Maximum number of documents to return.
            **kwargs: Passed to the index search, which ignores them.

        Returns:
            A list of ``(Document, distance)`` tuples. Each document's
            metadata contains its ``"id"`` merged with the stored metadata.
        """
        query_embedding = self._embedding_function.embed_query(query)
        indices_dists = self._similarity_index_search_with_score(
            query_embedding, k=k, **kwargs
        )
        return [
            (
                Document(
                    page_content=self._texts[idx],
                    metadata={"id": self._ids[idx], **self._metadatas[idx]},
                ),
                dist,
            )
            for idx, dist in indices_dists
        ]

    def similarity_search(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Document]:
        """Return the documents nearest to ``query``.

        Args:
            query: Text to embed and search with.
            k: Maximum number of documents to return.
            **kwargs: Passed to :meth:`similarity_search_with_score`.

        Returns:
            The matching documents, without their distances.
        """
        docs_scores = self.similarity_search_with_score(query, k=k, **kwargs)
        return [doc for doc, _ in docs_scores]

    def _similarity_search_with_relevance_scores(
        self, query: str, k: int = DEFAULT_K, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return documents with a relevance score derived from distance.

        Each distance ``d`` is mapped to ``1 / exp(d)``, so a distance of 0
        gives a score of 1 and larger distances give smaller scores.

        Args:
            query: Text to embed and search with.
            k: Maximum number of documents to return.
            **kwargs: Passed to :meth:`similarity_search_with_score`.

        Returns:
            A list of ``(Document, score)`` tuples.
        """
        docs_dists = self.similarity_search_with_score(query, k=k, **kwargs)
        # Pairwise comprehension also works when the result is empty.
        return [(doc, 1 / math.exp(dist)) for doc, dist in docs_dists]

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected using maximal marginal relevance (MMR).

        MMR optimizes for similarity to the query and diversity among the
        selected documents.

        Args:
            embedding: Embedding to look up similar documents for.
            k: Number of documents to return. Defaults to 4.
            fetch_k: Number of documents to fetch and pass to the MMR
                algorithm. Capped at the number of stored documents.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            **kwargs: Passed to the index search, which ignores them.

        Returns:
            List of documents selected by maximal marginal relevance.
        """
        indices_dists = self._similarity_index_search_with_score(
            embedding, k=fetch_k, **kwargs
        )
        indices = [idx for idx, _ in indices_dists]
        # Candidate embeddings, in the same order as ``indices``.
        result_embeddings = self._embeddings_np[indices]
        mmr_selected = maximal_marginal_relevance(
            self._np.array(embedding, dtype=self._np.float32),
            result_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        # MMR returns positions within the candidates; map them to store indices.
        mmr_indices = [indices[i] for i in mmr_selected]
        return [
            Document(
                page_content=self._texts[idx],
                metadata={"id": self._ids[idx], **self._metadatas[idx]},
            )
            for idx in mmr_indices
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        fetch_k: int = DEFAULT_FETCH_K,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents selected using maximal marginal relevance (MMR).

        MMR optimizes for similarity to the query and diversity among the
        selected documents.

        Args:
            query: Text to look up similar documents for.
            k: Number of documents to return. Defaults to 4.
            fetch_k: Number of documents to fetch and pass to the MMR
                algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            **kwargs: Passed to
                :meth:`max_marginal_relevance_search_by_vector`.

        Returns:
            List of documents selected by maximal marginal relevance.

        Raises:
            ValueError: If the store has no embedding function.
        """
        if self._embedding_function is None:
            raise ValueError(
                "For MMR search, you must specify an embedding function on creation."
            )
        embedding = self._embedding_function.embed_query(query)
        # The keyword was previously misspelled ``lambda_mul``, which landed in
        # **kwargs and was ignored, so the caller's value never took effect.
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        persist_path: Optional[str] = None,
        **kwargs: Any,
    ) -> "SKLearnVectorStore":
        """Create a store and add ``texts`` to it.

        Args:
            texts: Texts to add.
            embedding: Embedding function used for documents and queries.
            metadatas: Optional metadata dicts, one per text.
            ids: Optional ids, one per text.
            persist_path: Optional file path for saving/loading the store.
            **kwargs: Extra keyword arguments passed to the constructor.

        Returns:
            The populated vector store.
        """
        # Use ``cls`` so subclasses get an instance of their own type.
        vs = cls(embedding, persist_path=persist_path, **kwargs)
        vs.add_texts(texts, metadatas=metadatas, ids=ids)
        return vs