from __future__ import annotations
import logging
from copy import deepcopy
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
# Module-level logger; used below to report failed Rockset queries and deletes.
logger = logging.getLogger(__name__)
class Rockset(VectorStore):
    """`Rockset` vector store.

    To use, you should have the ``rockset`` python package installed. Note that
    the collection being used must already exist in your Rockset instance; this
    class does not create it.

    You must also make sure to use a Rockset ingest transformation to apply
    ``VECTOR_ENFORCE`` on the column used to store ``embedding_key`` in the
    collection.
    See: https://rockset.com/blog/introducing-vector-search-on-rockset/ for
    more details.

    Everything below assumes the ``commons`` Rockset workspace (the default
    value of the ``workspace`` argument).

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Rockset
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            import rockset

            # Make sure to use the right host (region) for your Rockset instance
            # and that the APIKEY has read-write access to your collection.

            rs = rockset.RocksetClient(host=rockset.Regions.use1a1, api_key="***")
            collection_name = "langchain_demo"
            embeddings = OpenAIEmbeddings()
            vectorstore = Rockset(rs, embeddings, collection_name,
                "description", "description_embedding")
    """

    def __init__(
        self,
        client: Any,
        embeddings: Embeddings,
        collection_name: str,
        text_key: str,
        embedding_key: str,
        workspace: str = "commons",
    ):
        """Initialize with a Rockset client.

        Args:
            client: A ``rockset.RocksetClient`` instance.
            embeddings: LangChain ``Embeddings`` object used to embed texts.
            collection_name: Rockset collection to insert documents into and
                query from.
            text_key: Column in the collection that stores the document text.
            embedding_key: Column in the collection that stores the embedding.
                NOTE: ``VECTOR_ENFORCE()`` must be applied to this column via a
                Rockset ingest transformation.
            workspace: Rockset workspace containing the collection.

        Raises:
            ImportError: If the ``rockset`` package is not installed.
            ValueError: If ``client`` is not a ``rockset.RocksetClient``.
        """
        try:
            from rockset import RocksetClient
        except ImportError as e:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from e

        if not isinstance(client, RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )
        # TODO: check that `collection_name` exists in rockset. Create if not.
        self._client = client
        self._collection_name = collection_name
        self._embeddings = embeddings
        self._text_key = text_key
        self._embedding_key = embedding_key
        self._workspace = workspace

        try:
            self._client.set_application("langchain")
        except AttributeError:
            # Best-effort tagging: clients without `set_application` are
            # still usable, so this is deliberately ignored.
            pass

    @property
    def embeddings(self) -> Embeddings:
        """The ``Embeddings`` object this store was constructed with."""
        return self._embeddings

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> List[str]:
        """Embed texts and add them to the vector store.

        Args:
            texts: Iterable of strings to add to the vector store.
            metadatas: Optional list of metadata dicts, aligned by position
                with ``texts``. Each dict is deep-copied before being extended
                with the text and embedding columns.
            ids: Optional list of ids, aligned by position with ``texts``;
                stored in Rockset's ``_id`` column.
            batch_size: Number of documents embedded and sent to Rockset per
                request. Values below 1 behave like 1.
            **kwargs: Ignored.

        Returns:
            List of ids (as reported by Rockset) of the added documents.
        """
        batch_docs: List[dict] = []
        batch_texts: List[str] = []
        stored_ids: List[str] = []
        for i, text in enumerate(texts):
            doc: dict = {}
            if metadatas and len(metadatas) > i:
                doc = deepcopy(metadatas[i])
            if ids and len(ids) > i:
                doc["_id"] = ids[i]
            doc[self._text_key] = text
            batch_docs.append(doc)
            batch_texts.append(text)
            # Flush after appending so a non-positive batch_size never sends
            # an empty batch.
            if len(batch_docs) >= batch_size:
                stored_ids += self._embed_and_write_batch(batch_docs, batch_texts)
                batch_docs, batch_texts = [], []
        if batch_docs:
            stored_ids += self._embed_and_write_batch(batch_docs, batch_texts)
        return stored_ids

    def _embed_and_write_batch(self, docs: List[dict], texts: List[str]) -> List[str]:
        """Embed ``texts`` in one call, attach vectors to ``docs``, and write them."""
        # Stored texts are documents, so they use `embed_documents` (not
        # `embed_query`); this also makes one embedding request per batch.
        vectors = self._embeddings.embed_documents(texts)
        if len(vectors) != len(docs):
            raise ValueError(
                f"Expected {len(docs)} embeddings from the embedding model, "
                f"got {len(vectors)}"
            )
        for doc, vector in zip(docs, vectors):
            doc[self._embedding_key] = vector
        return self._write_documents_to_rockset(docs)

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        client: Any = None,
        collection_name: str = "",
        text_key: str = "",
        embedding_key: str = "",
        ids: Optional[List[str]] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> Rockset:
        """Create a Rockset wrapper and add ``texts`` to it.

        This is intended as a quicker way to get started. ``client``,
        ``collection_name``, ``text_key`` and ``embedding_key`` are required
        despite having defaults.

        Raises:
            AssertionError: If any of the required arguments is missing/empty.
        """
        # Sanitize inputs
        assert client is not None, "Rockset Client cannot be None"
        assert collection_name, "Collection name cannot be empty"
        assert text_key, "Text key name cannot be empty"
        assert embedding_key, "Embedding key cannot be empty"

        store = cls(client, embedding, collection_name, text_key, embedding_key)
        store.add_texts(texts, metadatas=metadatas, ids=ids, batch_size=batch_size)
        return store

    # Rockset supports these vector distance functions.
    class DistanceFunction(Enum):
        COSINE_SIM = "COSINE_SIM"
        EUCLIDEAN_DIST = "EUCLIDEAN_DIST"
        DOT_PRODUCT = "DOT_PRODUCT"

        # how to sort results for "similarity"
        def order_by(self) -> str:
            """SQL sort direction that puts the most similar rows first."""
            # Euclidean is a distance (smaller = closer); the others are
            # similarities (larger = closer).
            if self.value == "EUCLIDEAN_DIST":
                return "ASC"
            return "DESC"

    def similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Perform a similarity search with Rockset.

        Args:
            query: Text to look up documents similar to.
            k: Number of nearest neighbours to retrieve. Defaults to 4.
            distance_func: How to compute the distance between two vectors in
                Rockset.
            where_str: Metadata filter supplied as an SQL ``where`` condition
                string, e.g. ``"price<=70.0 AND brand='Nintendo'"``.
                NOTE: it is inserted verbatim into the SQL; never let end users
                fill it in, and always guard against SQL injection.

        Returns:
            List of documents with their scores (the raw ``dist`` value
            computed by ``distance_func``).
        """
        return self.similarity_search_by_vector_with_relevance_scores(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Same as ``similarity_search_with_relevance_scores`` without scores."""
        return self.similarity_search_by_vector(
            self._embeddings.embed_query(query),
            k,
            distance_func,
            where_str,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return documents whose embeddings are most similar to ``embedding``."""
        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding, k, distance_func, where_str, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = 4,
        distance_func: DistanceFunction = DistanceFunction.COSINE_SIM,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return documents most similar to ``embedding`` with their scores.

        Keyword Args:
            exclude_embeddings (bool): When True (default) the embedding column
                is excluded from the returned metadata.

        Returns:
            ``(Document, dist)`` pairs ordered most-similar first. Returns an
            empty list (after logging) if the Rockset query itself fails.

        Raises:
            ValueError: If a result row is malformed (not a dict, missing the
                text or ``dist`` column, or holding values of the wrong type).
        """
        exclude_embeddings = kwargs.get("exclude_embeddings", True)
        q_str = self._build_query_sql(
            embedding, distance_func, k, where_str, exclude_embeddings
        )
        try:
            query_response = self._client.Queries.query(sql={"query": q_str})
        except Exception as e:
            # Deliberate best-effort: query failures are logged and produce no
            # results instead of propagating the client error.
            logger.error("Exception when querying Rockset: %s\n", e)
            return []
        return [self._parse_result_row(row) for row in query_response.results]

    def _parse_result_row(self, row: Any) -> Tuple[Document, float]:
        """Convert one Rockset result row into a ``(Document, score)`` pair."""
        if not isinstance(row, dict):
            raise ValueError(
                "document should be of type `dict[str,Any]`. But found: `{}`".format(
                    type(row)
                )
            )
        # Reset per row so values can never leak from a previous row.
        page_content: Optional[str] = None
        score: Optional[float] = None
        metadata: dict = {}
        for column, value in row.items():
            if column == self._text_key:
                if not isinstance(value, str):
                    raise ValueError(
                        (
                            "page content stored in column `{}` must be of type "
                            "`str`. But found: `{}`"
                        ).format(self._text_key, type(value))
                    )
                page_content = value
            elif column == "dist":
                if not isinstance(value, float):
                    raise ValueError(
                        (
                            "Computed distance between vectors must of type "
                            "`float`. But found {}"
                        ).format(type(value))
                    )
                score = value
            elif column not in ("_id", "_event_time", "_meta"):
                # These columns are populated by Rockset when documents are
                # inserted. No need to return them in metadata dict.
                metadata[column] = value
        if page_content is None:
            raise ValueError(
                f"Query result row is missing the text column `{self._text_key}`"
            )
        if score is None:
            raise ValueError("Query result row is missing the computed `dist` column")
        return Document(page_content=page_content, metadata=metadata), score

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        *,
        where_str: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of documents to return. Defaults to 4.
            fetch_k: Number of documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            where_str: Where clause of the SQL query (inserted verbatim).
            **kwargs: Forwarded to ``similarity_search_by_vector``; e.g.
                ``distance_func`` selects how Rockset computes vector distance.
                ``exclude_embeddings`` is ignored because MMR needs them.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        # MMR needs the stored embeddings; drop a caller override to avoid a
        # duplicate-keyword TypeError below.
        kwargs.pop("exclude_embeddings", None)
        query_embedding = self._embeddings.embed_query(query)
        initial_docs = self.similarity_search_by_vector(
            query_embedding,
            k=fetch_k,
            where_str=where_str,
            exclude_embeddings=False,
            **kwargs,
        )

        embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]

        selected_indices = maximal_marginal_relevance(
            np.array(query_embedding),
            embeddings,
            lambda_mult=lambda_mult,
            k=k,
        )

        # remove embeddings key before returning for cleanup to be consistent with
        # other search functions
        for i in selected_indices:
            del initial_docs[i].metadata[self._embedding_key]

        return [initial_docs[i] for i in selected_indices]

    # Helper functions

    def _build_query_sql(
        self,
        query_embedding: List[float],
        distance_func: DistanceFunction,
        k: int = 4,
        where_str: Optional[str] = None,
        exclude_embeddings: bool = True,
    ) -> str:
        """Build the Rockset SQL that ranks rows by distance to the query vector.

        NOTE: ``where_str`` is interpolated verbatim into the SQL text; never
        pass untrusted (end-user supplied) input here.
        """
        q_embedding_str = ",".join(map(str, query_embedding))
        distance_str = (
            f"{distance_func.value}({self._embedding_key}, [{q_embedding_str}]) as dist"
        )
        where_clause = f"WHERE {where_str}\n" if where_str else ""
        select_embedding = (
            f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
        )
        return (
            f"SELECT *{select_embedding} {distance_str}\n"
            f"FROM {self._workspace}.{self._collection_name}\n"
            f"{where_clause}"
            f"ORDER BY dist {distance_func.order_by()}\n"
            f"LIMIT {str(k)}\n"
        )

    def _write_documents_to_rockset(self, batch: List[dict]) -> List[str]:
        """Send one batch of documents to Rockset; return the ids it reports."""
        add_doc_res = self._client.Documents.add_documents(
            collection=self._collection_name, data=batch, workspace=self._workspace
        )
        return [doc_status._id for doc_status in add_doc_res.data]

    def delete_texts(self, ids: List[str]) -> None:
        """Delete a list of docs from the Rockset collection.

        Raises:
            ImportError: If the ``rockset`` package is not installed.
        """
        try:
            from rockset.models import DeleteDocumentsRequestData
        except ImportError as e:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from e

        self._client.Documents.delete_documents(
            collection=self._collection_name,
            data=[DeleteDocumentsRequestData(id=i) for i in ids],
            workspace=self._workspace,
        )

    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents by id; ``None`` deletes nothing.

        Returns:
            True on success, False if the delete raised (the error is logged).
        """
        try:
            if ids is None:
                ids = []
            self.delete_texts(ids)
        except Exception as e:
            logger.error("Exception when deleting docs from Rockset: %s\n", e)
            return False

        return True

    async def adelete(
        self, ids: Optional[List[str]] = None, **kwargs: Any
    ) -> Optional[bool]:
        """Async variant of ``delete``, run in the default executor."""
        return await run_in_executor(None, self.delete, ids, **kwargs)