Source code for langchain_community.vectorstores.rocksetdb

from __future__ import annotations

import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class Rockset(VectorStore):
    """`Rockset` vector store.

    To use, you should have the `rockset` python package installed. Note that to
    use this, the collection being used must already exist in your Rockset
    instance. You must also ensure you use a Rockset ingest transformation to
    apply `VECTOR_ENFORCE` on the column used to store `embedding_key` in the
    collection.
    See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for more
    details.

    Everything below assumes the `commons` Rockset workspace.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Rockset
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            import rockset

            # Make sure you use the right host (region) for your Rockset instance
            # and that the API key has read-write access to the collection.
            rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
            collection_name = "langchain_demo"
            embeddings = OpenAIEmbeddings()

            vectorstore = Rockset(rs, embeddings, collection_name,
                "description", "description_embedding")
    """
    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        collection_name: str,
        text_key: str,
        embedding_key: str,
        workspace: str = "commons",
    ):
        """Initialize with a Rockset client.

        Args:
            client: Rockset client object.
            embeddings: LangChain Embeddings object used to generate embeddings
                for given texts.
            collection_name: Rockset collection to insert docs into / query.
            text_key: column in the Rockset collection used to store the text.
            embedding_key: column in the Rockset collection used to store the
                embedding. Note: `VECTOR_ENFORCE()` must be applied on this
                column via a Rockset ingest transformation.
        """
        try:
            from rockset import RocksetClient
        except ImportError:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            )

        if not isinstance(client, RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )
        # TODO: check that `collection_name` exists in rockset. Create if not.
        self._client = client
        self._collection_name = collection_name
        self._embeddings = embeddings
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._workspace = workspace

        try:
            self._client.set_application("langchain")
        except AttributeError:
            # Older rockset clients may not support setting the application
            # name; ignore.
            pass
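    # A sketch of the ingest transformation the docstring refers to, following
    # the referenced Rockset blog post (column name and the 1536 dimension are
    # illustrative and must match your embedding model):
    #
    #     SELECT
    #         _input.* EXCEPT(_meta),
    #         VECTOR_ENFORCE(_input.description_embedding, 1536, 'float')
    #             AS description_embedding
    #     FROM _input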
    @property
    def embeddings(self) -> Embeddings:
        return self._embeddings
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add them to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.
            batch_size: Send documents to Rockset in batches of this size.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        batch: list[dict] = []
        stored_ids = []

        for i, text in enumerate(texts):
            if len(batch) == batch_size:
                stored_ids += self._write_documents_to_rockset(batch)
                batch = []
            doc = {}
            if metadatas and len(metadatas) > i:
                doc = deepcopy(metadatas[i])
            if ids and len(ids) > i:
                doc["_id"] = ids[i]
            doc[self._text_key] = text
            doc[self._embedding_key] = self._embeddings.embed_query(text)
            batch.append(doc)
        if len(batch) > 0:
            stored_ids += self._write_documents_to_rockset(batch)
            batch = []
        return stored_ids
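    # A minimal usage sketch for `add_texts` (assumes `vectorstore` was
    # constructed as in the class docstring; texts, ids and metadata are
    # illustrative). Each text is embedded individually via `embed_query()`
    # and documents are written to the collection in `batch_size` chunks:
    #
    #     ids = vectorstore.add_texts(
    #         texts=["Portal is a puzzle game", "Mario Kart is a racing game"],
    #         metadatas=[{"brand": "Valve"}, {"brand": "Nintendo"}],
    #         ids=["doc-1", "doc-2"],
    #         batch_size=32,
    #     )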
    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        collection_name: str = "",
        text_key: str = "",
        embedding_key: str = "",
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> Rockset:
        """Create a Rockset wrapper from existing texts.

        This is intended as a quicker way to get started.
        """
        # Sanitize inputs
        assert client is not None, "Rockset Client cannot be None"
        assert collection_name, "Collection name cannot be empty"
        assert text_key, "Text key name cannot be empty"
        assert embedding_key, "Embedding key cannot be empty"

        rockset = cls(client, embedding, collection_name, text_key, embedding_key)
        rockset.add_texts(texts, metadatas, ids, batch_size)
        return rockset
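    # A `from_texts` sketch (assumes the `rs` client, `embeddings`, and the
    # "langchain_demo" collection from the class docstring already exist):
    #
    #     vectorstore = Rockset.from_texts(
    #         texts=["hello", "world"],
    #         embedding=embeddings,
    #         client=rs,
    #         collection_name="langchain_demo",
    #         text_key="description",
    #         embedding_key="description_embedding",
    #     )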
    # Rockset supports these vector distance functions.
    class DistanceFunction(Enum):
        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

        # How to sort results for "similarity": Euclidean distance grows as
        # vectors get less similar, so it sorts ascending; cosine similarity
        # and dot product grow as vectors get more similar, so they sort
        # descending.
        def order_by(self) -> str:
            if self.value == "EUCLIDEAN_DIST":
                return "ASC"
            return "DESC"
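    # For example:
    #
    #     Rockset.DistanceFunction.EUCLIDEAN_DIST.order_by()  # -> "ASC"
    #     Rockset.DistanceFunction.COSINE_SIM.order_by()      # -> "DESC"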
    def similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with Rockset.

        Args:
            query (str): Text to look up documents similar to.
            distance_func (DistanceFunction): how to compute the distance between
                two vectors in Rockset.
            k (int, optional): Top K neighbors to retrieve. Defaults to 4.
            where_str (Optional[str], optional): Metadata filters supplied as a
                SQL `where` condition string. Defaults to None.
                eg. "price<=70.0 AND brand='Nintendo'"

                NOTE: Please do not let end users fill this in, and always be
                mindful of SQL injection.

        Returns:
            List[Tuple[Document, float]]: List of documents with their
            relevance scores.
        """
        return self.similarity_search_by_vector_with_relevance_scores(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )
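    # A usage sketch with a metadata filter (values are illustrative; build
    # `where_str` server-side, never from raw end-user input):
    #
    #     results = vectorstore.similarity_search_with_relevance_scores(
    #         "affordable puzzle games",
    #         k=4,
    #         distance_func=Rockset.DistanceFunction.COSINE_SIM,
    #         where_str="price<=70.0 AND brand='Nintendo'",
    #     )
    #     for doc, score in results:
    #         print(score, doc.page_content)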
    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Accept a query embedding (vector) and return documents with
        similar embeddings."""
        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding, k, distance_func, where_str, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]
    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Accept a query embedding (vector) and return documents with
        similar embeddings along with their relevance scores."""
        exclude_embeddings = kwargs.get("exclude_embeddings", True)
        q_str = self._build_query_sql(
            embedding, distance_func, k, where_str, exclude_embeddings
        )
        try:
            query_response = self._client.Queries.query(sql={"query": q_str})
        except Exception as e:
            logger.error("Exception when querying Rockset: %s\n", e)
            return []
        final_result: list[Tuple[Document, float]] = []
        for document in query_response.results:
            metadata = {}
            assert isinstance(
                document, dict
            ), "document should be of type `dict[str,Any]`. But found: `{}`".format(
                type(document)
            )
            # Use `key`/`value` here so the `k` (top-K) parameter isn't shadowed.
            for key, value in document.items():
                if key == self._text_key:
                    assert isinstance(value, str), (
                        "page content stored in column `{}` must be of type `str`. "
                        "But found: `{}`"
                    ).format(self._text_key, type(value))
                    page_content = value
                elif key == "dist":
                    assert isinstance(value, float), (
                        "Computed distance between vectors must be of type `float`. "
                        "But found {}"
                    ).format(type(value))
                    score = value
                elif key not in ["_id", "_event_time", "_meta"]:
                    # These columns are populated by Rockset when documents are
                    # inserted. No need to return them in the metadata dict.
                    metadata[key] = value
            final_result.append(
                (Document(page_content=page_content, metadata=metadata), score)
            )
        return final_result
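    # By default the stored embedding column is excluded from the returned
    # metadata; pass exclude_embeddings=False to keep it (illustrative call):
    #
    #     docs_and_scores = (
    #         vectorstore.similarity_search_by_vector_with_relevance_scores(
    #             embeddings.embed_query("puzzle games"),
    #             k=2,
    #             exclude_embeddings=False,
    #         )
    #     )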
    # Helper functions

    def _build_query_sql(
        self,
        query_embedding: List[float],
        distance_func: DistanceFunction,
        k: int = 4,
        where_str: Optional[str] = None,
        exclude_embeddings: bool = True,
    ) -> str:
        """Build a Rockset SQL query that searches for vectors similar to the
        query vector."""
        q_embedding_str = ",".join(map(str, query_embedding))
        distance_str = f"""{distance_func.value}({self._embedding_key}, \
[{q_embedding_str}]) as dist"""
        where_str = f"WHERE {where_str}\n" if where_str else ""
        select_embedding = (
            f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
        )
        return f"""\
SELECT *{select_embedding} {distance_str}
FROM {self._workspace}.{self._collection_name}
{where_str}\
ORDER BY dist {distance_func.order_by()}
LIMIT {str(k)}
"""

    def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
        add_doc_res = self._client.Documents.add_documents(
            collection=self._collection_name, data=batch, workspace=self._workspace
        )
        return [doc_status._id for doc_status in add_doc_res.data]
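    # With the class-docstring setup, a 3-dimensional query vector, and the
    # filter from the search example, the generated SQL looks roughly like
    # this (values are hypothetical):
    #
    #     SELECT * EXCEPT(description_embedding),
    #         COSINE_SIM(description_embedding, [0.1,0.2,0.3]) as dist
    #     FROM commons.langchain_demo
    #     WHERE price<=70.0 AND brand='Nintendo'
    #     ORDER BY dist DESC
    #     LIMIT 4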
    def delete_texts(self, ids: List[str]) -> None:
        """Delete a list of docs from the Rockset collection."""
        try:
            from rockset.models import DeleteDocumentsRequestData
        except ImportError:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            )

        self._client.Documents.delete_documents(
            collection=self._collection_name,
            data=[DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=self._workspace,
        )
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        try:
            if ids is None:
                ids = []
            self.delete_texts(ids)
        except Exception as e:
            logger.error("Exception when deleting docs from Rockset: %s\n", e)
            return False

        return True
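    # Deleting by id (sketch; `ids` as returned by `add_texts` above):
    #
    #     ok = vectorstore.delete(ids=["doc-1", "doc-2"])  # True on success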
    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        return await run_in_executor(None, self.delete, ids, **kwargs)
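    # The async variant simply runs `delete` in an executor; a sketch:
    #
    #     import asyncio
    #     asyncio.run(vectorstore.adelete(ids=["doc-1"]))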