from __future__ import annotations
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
if TYPE_CHECKING:
import bagel
import bagel.config
from bagel.api.types import ID, OneOrMany, Where, WhereDocument
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore
DEFAULT_K = 5
def _results_to_docs(results: Any) -> List[Document]:
return [doc for doc, _ in _results_to_docs_and_scores(results)]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
class Bagel(VectorStore):
    """``Bagel.net`` inference platform.

    To use, you should have the ``bagelML`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Bagel

            vectorstore = Bagel(cluster_name="langchain_store")
    """

    _LANGCHAIN_DEFAULT_CLUSTER_NAME = "langchain"

    def __init__(
        self,
        cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
        client_settings: Optional[bagel.config.Settings] = None,
        embedding_function: Optional[Embeddings] = None,
        cluster_metadata: Optional[Dict] = None,
        client: Optional[bagel.Client] = None,
        relevance_score_fn: Optional[Callable[[float], float]] = None,
    ) -> None:
        """Initialize with a Bagel client and get or create the cluster.

        Args:
            cluster_name: Name of the cluster passed to
                ``get_or_create_cluster``.
            client_settings: Settings used to build a new ``bagel.Client``
                when ``client`` is not given. If omitted, a REST client
                pointed at ``api.bageldb.ai`` is configured.
            embedding_function: Optional embedding model. When set, texts and
                queries are embedded locally before being sent to Bagel.
            cluster_metadata: Metadata passed to ``get_or_create_cluster``.
            client: An existing ``bagel.Client``. When given, it is used
                as-is and ``client_settings`` is only stored.
            relevance_score_fn: Optional override returned by
                ``_select_relevance_score_fn``.

        Raises:
            ImportError: If the ``bagelML`` package is not installed.
        """
        try:
            import bagel
            import bagel.config
        except ImportError:
            raise ImportError("Please install bagel `pip install bagelML`.")
        if client is not None:
            self._client_settings = client_settings
            self._client = client
        else:
            if client_settings:
                _client_settings = client_settings
            else:
                _client_settings = bagel.config.Settings(
                    bagel_api_impl="rest",
                    bagel_server_host="api.bageldb.ai",
                )
            self._client_settings = _client_settings
            self._client = bagel.Client(_client_settings)
        self._cluster = self._client.get_or_create_cluster(
            name=cluster_name,
            metadata=cluster_metadata,
        )
        self.override_relevance_score_fn = relevance_score_fn
        self._embedding_function = embedding_function

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """The embedding function configured at construction, if any."""
        return self._embedding_function

    @xor_args(("query_texts", "query_embeddings"))
    def __query_cluster(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Query the Bagel cluster and return the raw ``find`` response.

        Exactly one of ``query_texts`` / ``query_embeddings`` may be given
        (enforced by ``xor_args``). If an embedding function is configured
        and texts are given, the texts are embedded locally and sent as
        embeddings instead.

        Raises:
            ImportError: If the ``bagelML`` package is not installed.
        """
        try:
            import bagel  # noqa: F401
        except ImportError:
            raise ImportError("Please install bagel `pip install bagelML`.")
        if self._embedding_function and query_embeddings is None and query_texts:
            texts = list(query_texts)
            query_embeddings = self._embedding_function.embed_documents(texts)
            # Send only the embeddings; the texts have been replaced by them.
            query_texts = None
        return self._cluster.find(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            **kwargs,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        embeddings: Optional[List[List[float]]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts, with optional embeddings and metadata, to the cluster.

        Args:
            texts: Texts to add. Any iterable is accepted, including a
                one-shot generator; it is materialized into a list first.
            metadatas: Optional metadata dicts aligned with ``texts``. If
                shorter than ``texts``, it is padded with empty dicts.
            ids: Optional unique IDs aligned with ``texts``. Random UUID4
                strings are generated when omitted.
            embeddings: Optional precomputed embeddings aligned with
                ``texts``. When omitted and an embedding function is
                configured, the texts are embedded with it.

        Returns:
            The IDs of the added texts.
        """
        # Materialize before generating ids: iterating a one-shot iterator
        # to build the ids would otherwise exhaust it, leaving nothing to
        # upsert while still returning ids.
        texts = list(texts)
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]
        if self._embedding_function and embeddings is None and texts:
            embeddings = self._embedding_function.embed_documents(texts)
        if metadatas:
            length_diff = len(texts) - len(metadatas)
            if length_diff:
                metadatas = metadatas + [{}] * length_diff
            empty_ids = []
            non_empty_ids = []
            for idx, metadata in enumerate(metadatas):
                if metadata:
                    non_empty_ids.append(idx)
                else:
                    empty_ids.append(idx)
            # Entries with and without metadata go out in two upsert calls;
            # the second call omits the ``metadatas`` argument entirely.
            if non_empty_ids:
                metadatas_with_content = [metadatas[idx] for idx in non_empty_ids]
                texts_with_metadatas = [texts[idx] for idx in non_empty_ids]
                embeddings_with_metadatas = (
                    [embeddings[idx] for idx in non_empty_ids] if embeddings else None
                )
                ids_with_metadata = [ids[idx] for idx in non_empty_ids]
                self._cluster.upsert(
                    embeddings=embeddings_with_metadatas,
                    metadatas=metadatas_with_content,
                    documents=texts_with_metadatas,
                    ids=ids_with_metadata,
                )
            if empty_ids:
                texts_without_metadatas = [texts[j] for j in empty_ids]
                embeddings_without_metadatas = (
                    [embeddings[j] for j in empty_ids] if embeddings else None
                )
                ids_without_metadatas = [ids[j] for j in empty_ids]
                self._cluster.upsert(
                    embeddings=embeddings_without_metadatas,
                    documents=texts_without_metadatas,
                    ids=ids_without_metadatas,
                )
        else:
            metadatas = [{}] * len(texts)
            self._cluster.upsert(
                embeddings=embeddings,
                documents=texts,
                metadatas=metadatas,
                ids=ids,
            )
        return ids

    def similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run a similarity search with Bagel.

        Args:
            query: Query text to search for similar documents.
            k: Number of results to return.
            where: Optional metadata filter to narrow the search.

        Returns:
            The documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, where=where)
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = DEFAULT_K,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run a similarity search with Bagel and return documents with scores.

        Args:
            query: Query text to search for similar documents.
            k: Number of results to return.
            where: Optional metadata filter.

        Returns:
            ``(Document, score)`` tuples, where the score is the raw value
            from the response's ``"distances"`` field, not a normalized
            similarity.
        """
        results = self.__query_cluster(query_texts=[query], n_results=k, where=where)
        return _results_to_docs_and_scores(results)

    @classmethod
    def from_texts(
        cls: Type[Bagel],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
        client_settings: Optional[bagel.config.Settings] = None,
        cluster_metadata: Optional[Dict] = None,
        client: Optional[bagel.Client] = None,
        text_embeddings: Optional[List[List[float]]] = None,
        **kwargs: Any,
    ) -> Bagel:
        """Create a Bagel vector store and add the given texts to it.

        Args:
            texts: Texts to add.
            embedding: Optional embedding model used as the store's
                embedding function.
            metadatas: Optional metadata dicts aligned with ``texts``.
            ids: Optional unique IDs aligned with ``texts``.
            cluster_name: Name of the Bagel cluster.
            client_settings: Client settings.
            cluster_metadata: Metadata for the cluster.
            client: An existing Bagel client instance.
            text_embeddings: Optional precomputed embeddings for ``texts``,
                passed to ``add_texts`` as ``embeddings``.
            **kwargs: Forwarded to the constructor.

        Returns:
            The Bagel vector store.
        """
        bagel_cluster = cls(
            cluster_name=cluster_name,
            embedding_function=embedding,
            client_settings=client_settings,
            client=client,
            cluster_metadata=cluster_metadata,
            **kwargs,
        )
        _ = bagel_cluster.add_texts(
            texts=texts, embeddings=text_embeddings, metadatas=metadatas, ids=ids
        )
        return bagel_cluster

    def delete_cluster(self) -> None:
        """Delete this store's cluster through the client."""
        self._client.delete_cluster(self._cluster.name)

    def similarity_search_by_vector_with_relevance_scores(
        self,
        query_embeddings: List[float],
        k: int = DEFAULT_K,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return the documents closest to an embedding, with their scores.

        The score is the raw value from the response's ``"distances"``
        field.
        """
        results = self.__query_cluster(
            query_embeddings=query_embeddings, n_results=k, where=where
        )
        return _results_to_docs_and_scores(results)

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = DEFAULT_K,
        where: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return the documents closest to an embedding vector."""
        results = self.__query_cluster(
            query_embeddings=embedding, n_results=k, where=where
        )
        return _results_to_docs(results)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Pick the distance-to-relevance function for this cluster.

        An explicit ``relevance_score_fn`` from the constructor wins.
        Otherwise the metric is read from the cluster metadata key
        ``"hnsw:space"``; if that key is absent, ``"l2"`` is assumed.

        Raises:
            ValueError: If the metric is not ``cosine``, ``l2`` or ``ip``.
        """
        if self.override_relevance_score_fn:
            return self.override_relevance_score_fn
        distance = "l2"
        distance_key = "hnsw:space"
        metadata = self._cluster.metadata
        if metadata and distance_key in metadata:
            distance = metadata[distance_key]
        if distance == "cosine":
            return self._cosine_relevance_score_fn
        elif distance == "l2":
            return self._euclidean_relevance_score_fn
        elif distance == "ip":
            return self._max_inner_product_relevance_score_fn
        else:
            raise ValueError(
                "No supported normalization function for distance"
                f" metric of type: {distance}. Consider providing"
                " relevance_score_fn to Bagel constructor."
            )

    @classmethod
    def from_documents(
        cls: Type[Bagel],
        documents: List[Document],
        embedding: Optional[Embeddings] = None,
        ids: Optional[List[str]] = None,
        cluster_name: str = _LANGCHAIN_DEFAULT_CLUSTER_NAME,
        client_settings: Optional[bagel.config.Settings] = None,
        client: Optional[bagel.Client] = None,
        cluster_metadata: Optional[Dict] = None,
        **kwargs: Any,
    ) -> Bagel:
        """Create a Bagel vector store from a list of Documents.

        Args:
            documents: Documents whose ``page_content`` and ``metadata``
                are added to the store.
            embedding: Optional embedding model used as the store's
                embedding function.
            ids: Optional list of IDs. Defaults to None.
            cluster_name: Name of the Bagel cluster.
            client_settings: Client settings.
            client: An existing Bagel client instance.
            cluster_metadata: Metadata for the cluster. Defaults to None.
            **kwargs: Forwarded to ``from_texts``.

        Returns:
            The Bagel vector store.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return cls.from_texts(
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            cluster_name=cluster_name,
            client_settings=client_settings,
            client=client,
            cluster_metadata=cluster_metadata,
            **kwargs,
        )

    def update_document(self, document_id: str, document: Document) -> None:
        """Update a document in the cluster.

        Args:
            document_id: ID of the document to update.
            document: New content and metadata for that ID.
        """
        text = document.page_content
        metadata = document.metadata
        self._cluster.update(
            ids=[document_id],
            documents=[text],
            metadatas=[metadata],
        )

    def get(
        self,
        ids: Optional[OneOrMany[ID]] = None,
        where: Optional[Where] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Get entries from the cluster.

        ``include`` is forwarded only when given, so the client's own
        default applies otherwise.
        """
        kwargs: Dict[str, Any] = {
            "ids": ids,
            "where": where,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
        }
        if include is not None:
            kwargs["include"] = include
        return self._cluster.get(**kwargs)

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete entries by ID.

        Args:
            ids: IDs to delete. Passed straight to the cluster's ``delete``;
                what ``None`` does is up to the Bagel client (NOTE(review):
                confirm whether it deletes everything).
        """
        self._cluster.delete(ids=ids)