from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from databricks.vector_search.client import VectorSearchIndex
logger = logging.getLogger(__name__)
[docs]class DatabricksVectorSearch(VectorStore):
"""`Databricks Vector Search` 向量存储。
要使用,应安装 ``databricks-vectorsearch`` python 包。
示例:
.. code-block:: python
from langchain_community.vectorstores import DatabricksVectorSearch
from databricks.vector_search.client import VectorSearchClient
vs_client = VectorSearchClient()
vs_index = vs_client.get_index(
endpoint_name="vs_endpoint",
index_name="ml.llm.index"
)
vectorstore = DatabricksVectorSearch(vs_index)
参数:
index: 一个 Databricks Vector Search 索引对象。
embedding: 嵌入模型。
对于直接访问索引或具有自管理嵌入的增量同步索引,需要。
text_column: 用于嵌入的文本列的名称。
对于直接访问索引或具有自管理嵌入的增量同步索引,需要。
确保指定的文本列在索引中。
columns: 在进行搜索时要获取的列名列表。
默认为 ``[primary_key, text_column]``。
具有由 Databricks 管理的嵌入的增量同步索引会为您管理摄入、删除和嵌入。
不支持手动摄入/删除文档/文本以用于增量同步索引。
如果要使用具有自管理嵌入的增量同步索引,需要提供嵌入模型和要用于嵌入的文本列名称。
示例:
.. code-block:: python
from langchain_community.vectorstores import DatabricksVectorSearch
from databricks.vector_search.client import VectorSearchClient
from langchain_community.embeddings.openai import OpenAIEmbeddings
vs_client = VectorSearchClient()
vs_index = vs_client.get_index(
endpoint_name="vs_endpoint",
index_name="ml.llm.index"
)
vectorstore = DatabricksVectorSearch(
index=vs_index,
embedding=OpenAIEmbeddings(),
text_column="document_content"
)
如果要自行管理文档的摄入/删除,可以使用直接访问索引。
示例:
.. code-block:: python
from langchain_community.vectorstores import DatabricksVectorSearch
from databricks.vector_search.client import VectorSearchClient
from langchain_community.embeddings.openai import OpenAIEmbeddings
vs_client = VectorSearchClient()
vs_index = vs_client.get_index(
endpoint_name="vs_endpoint",
index_name="ml.llm.index"
)
vectorstore = DatabricksVectorSearch(
index=vs_index,
embedding=OpenAIEmbeddings(),
text_column="document_content"
)
vectorstore.add_texts(
texts=["text1", "text2"]
)
有关 Databricks Vector Search 的更多信息,请参阅 `Databricks Vector Search
文档: https://docs.databricks.com/en/generative-ai/vector-search.html."""
[docs] def __init__(
self,
index: VectorSearchIndex,
*,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
):
try:
from databricks.vector_search.client import VectorSearchIndex
except ImportError as e:
raise ImportError(
"Could not import databricks-vectorsearch python package. "
"Please install it with `pip install databricks-vectorsearch`."
) from e
# index
self.index = index
if not isinstance(index, VectorSearchIndex):
raise TypeError("index must be of type VectorSearchIndex.")
# index_details
index_details = self.index.describe()
self.primary_key = index_details["primary_key"]
self.index_type = index_details.get("index_type")
self._delta_sync_index_spec = index_details.get("delta_sync_index_spec", dict())
self._direct_access_index_spec = index_details.get(
"direct_access_index_spec", dict()
)
# text_column
if self._is_databricks_managed_embeddings():
index_source_column = self._embedding_source_column_name()
# check if input text column matches the source column of the index
if text_column is not None and text_column != index_source_column:
raise ValueError(
f"text_column '{text_column}' does not match with the "
f"source column of the index: '{index_source_column}'."
)
self.text_column = index_source_column
else:
self._require_arg(text_column, "text_column")
self.text_column = text_column
# columns
self.columns = columns or []
# add primary key column and source column if not in columns
if self.primary_key not in self.columns:
self.columns.append(self.primary_key)
if self.text_column and self.text_column not in self.columns:
self.columns.append(self.text_column)
# Validate specified columns are in the index
if self._is_direct_access_index():
index_schema = self._index_schema()
if index_schema:
for col in self.columns:
if col not in index_schema:
raise ValueError(
f"column '{col}' is not in the index's schema."
)
# embedding model
if not self._is_databricks_managed_embeddings():
# embedding model is required for direct-access index
# or delta-sync index with self-managed embedding
self._require_arg(embedding, "embedding")
self._embedding = embedding
# validate dimension matches
index_embedding_dimension = self._embedding_vector_column_dimension()
if index_embedding_dimension is not None:
inferred_embedding_dimension = self._infer_embedding_dimension()
if inferred_embedding_dimension != index_embedding_dimension:
raise ValueError(
f"embedding model's dimension '{inferred_embedding_dimension}' "
f"does not match with the index's dimension "
f"'{index_embedding_dimension}'."
)
else:
if embedding is not None:
logger.warning(
"embedding model is not used in delta-sync index with "
"Databricks-managed embeddings."
)
self._embedding = None
[docs] @classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError(
"`from_texts` is not supported. "
"Use `add_texts` to add to existing direct-access index."
)
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[Any]] = None,
**kwargs: Any,
) -> List[str]:
"""将文本添加到索引中。
仅支持直接访问索引。
参数:
texts:要添加的文本列表。
metadatas:每个文本的元数据列表。默认为None。
ids:每个文本的id列表。默认为None。
如果未提供,则将为每个文本生成一个随机uuid。
返回:
将文本添加到索引后的id列表。
"""
self._op_require_direct_access_index("add_texts")
assert self.embeddings is not None, "embedding model is required."
# Wrap to list if input texts is a single string
if isinstance(texts, str):
texts = [texts]
texts = list(texts)
vectors = self.embeddings.embed_documents(texts)
ids = ids or [str(uuid.uuid4()) for _ in texts]
metadatas = metadatas or [{} for _ in texts]
updates = [
{
self.primary_key: id_,
self.text_column: text,
self._embedding_vector_column_name(): vector,
**metadata,
}
for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
]
upsert_resp = self.index.upsert(updates)
if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
failed_ids = upsert_resp.get("result", dict()).get(
"failed_primary_keys", []
)
if upsert_resp.get("status") == "FAILURE":
logger.error("Failed to add texts to the index.")
else:
logger.warning("Some texts failed to be added to the index.")
return [id_ for id_ in ids if id_ not in failed_ids]
return ids
@property
def embeddings(self) -> Optional[Embeddings]:
"""如果可用,访问查询嵌入对象。"""
return self._embedding
[docs] def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
"""从索引中删除文档。
仅支持直接访问索引。
参数:
ids: 要删除的文档的id列表。
返回:
如果成功则返回True。
"""
self._op_require_direct_access_index("delete")
if ids is None:
raise ValueError("ids must be provided.")
self.index.delete(ids)
return True
[docs] def similarity_search(
self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
) -> List[Document]:
"""返回与查询最相似的文档。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
filters:要应用于查询的过滤器。默认为None。
返回:
与嵌入最相似的文档列表。
"""
docs_with_score = self.similarity_search_with_score(
query=query, k=k, filters=filters, **kwargs
)
return [doc for doc, _ in docs_with_score]
[docs] def similarity_search_with_score(
self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档,以及分数。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
filters:要应用于查询的过滤器。默认为None。
返回:
返回与嵌入最相似的文档列表,以及每个文档的分数。
"""
if self._is_databricks_managed_embeddings():
query_text = query
query_vector = None
else:
assert self.embeddings is not None, "embedding model is required."
query_text = None
query_vector = self.embeddings.embed_query(query)
search_resp = self.index.similarity_search(
columns=self.columns,
query_text=query_text,
query_vector=query_vector,
filters=filters,
num_results=k,
)
return self._parse_search_response(search_resp)
@staticmethod
def _identity_fn(score: float) -> float:
return score
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""Databricks矢量搜索使用标准化得分1/(1+d),其中d是L2距离。因此,我们简单地返回恒等函数。
"""
return self._identity_fn
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filters: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query:要查找类似文档的文本。
k:要返回的文档数量。默认为4。
fetch_k:要获取以传递给MMR算法的文档数量。
lambda_mult:0到1之间的数字,确定结果之间多样性的程度,0对应最大多样性,1对应最小多样性。默认为0.5。
filters:要应用于查询的过滤器。默认为无。
返回:
由最大边际相关性选择的文档列表。
"""
if not self._is_databricks_managed_embeddings():
assert self.embeddings is not None, "embedding model is required."
query_vector = self.embeddings.embed_query(query)
else:
raise ValueError(
"`max_marginal_relevance_search` is not supported for index with "
"Databricks-managed embeddings."
)
docs = self.max_marginal_relevance_search_by_vector(
query_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filters=filters,
)
return docs
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filters: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 用于查找类似文档的嵌入。
k: 要返回的文档数量。默认为4。
fetch_k: 要获取的文档数量以传递给MMR算法。
lambda_mult: 0到1之间的数字,确定结果之间多样性的程度,其中0对应最大多样性,1对应最小多样性。默认为0.5。
filters: 要应用于查询的过滤器。默认为无。
返回:
通过最大边际相关性选择的文档列表。
"""
if not self._is_databricks_managed_embeddings():
embedding_column = self._embedding_vector_column_name()
else:
raise ValueError(
"`max_marginal_relevance_search` is not supported for index with "
"Databricks-managed embeddings."
)
search_resp = self.index.similarity_search(
columns=list(set(self.columns + [embedding_column])),
query_text=None,
query_vector=embedding,
filters=filters,
num_results=fetch_k,
)
embeddings_result_index = (
search_resp.get("manifest").get("columns").index({"name": embedding_column})
)
embeddings = [
doc[embeddings_result_index]
for doc in search_resp.get("result").get("data_array")
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
ignore_cols: List = (
[embedding_column] if embedding_column not in self.columns else []
)
candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
return selected_results
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filters: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找相似文档的嵌入。
k: 要返回的文档数量。默认为4。
filters: 要应用于查询的过滤器。默认为None。
返回:
与嵌入最相似的文档列表。
"""
docs_with_score = self.similarity_search_by_vector_with_score(
embedding=embedding, k=k, filters=filters, **kwargs
)
return [doc for doc, _ in docs_with_score]
[docs] def similarity_search_by_vector_with_score(
self,
embedding: List[float],
k: int = 4,
filters: Optional[Any] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回与嵌入向量最相似的文档,以及相似度分数。
参数:
embedding:要查找相似文档的嵌入。
k:要返回的文档数量。默认为4。
filters:要应用于查询的过滤器。默认为None。
返回:
返回与嵌入最相似的文档列表,以及每个文档的分数。
"""
if self._is_databricks_managed_embeddings():
raise ValueError(
"`similarity_search_by_vector` is not supported for index with "
"Databricks-managed embeddings."
)
search_resp = self.index.similarity_search(
columns=self.columns,
query_vector=embedding,
filters=filters,
num_results=k,
)
return self._parse_search_response(search_resp)
def _parse_search_response(
self, search_resp: dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""将搜索响应解析为带有分数的文档列表。"""
if ignore_cols is None:
ignore_cols = []
columns = [
col["name"]
for col in search_resp.get("manifest", dict()).get("columns", [])
]
docs_with_score = []
for result in search_resp.get("result", dict()).get("data_array", []):
doc_id = result[columns.index(self.primary_key)]
text_content = result[columns.index(self.text_column)]
metadata = {
col: value
for col, value in zip(columns[:-1], result[:-1])
if col not in ([self.primary_key, self.text_column] + ignore_cols)
}
metadata[self.primary_key] = doc_id
score = result[-1]
doc = Document(page_content=text_content, metadata=metadata)
docs_with_score.append((doc, score))
return docs_with_score
def _index_schema(self) -> Optional[dict]:
"""返回索引模式作为字典。
如果找不到模式,则返回None。
"""
if self._is_direct_access_index():
schema_json = self._direct_access_index_spec.get("schema_json")
if schema_json is not None:
return json.loads(schema_json)
return None
def _embedding_vector_column_name(self) -> Optional[str]:
"""返回嵌入向量列的名称。
如果索引不是自管理的嵌入索引,则返回None。
"""
return self._embedding_vector_column().get("name")
def _embedding_vector_column_dimension(self) -> Optional[int]:
"""返回嵌入向量列的维度。
如果索引不是自管理的嵌入索引,则返回None。
"""
return self._embedding_vector_column().get("embedding_dimension")
def _embedding_vector_column(self) -> dict:
"""将嵌入向量列配置作为字典返回。
如果索引不是自管理的嵌入索引,则为空。
"""
index_spec = (
self._delta_sync_index_spec
if self._is_delta_sync_index()
else self._direct_access_index_spec
)
return next(iter(index_spec.get("embedding_vector_columns") or list()), dict())
def _embedding_source_column_name(self) -> Optional[str]:
"""返回嵌入源列的名称。
如果索引不是由Databricks管理的嵌入索引,则返回None。
"""
return self._embedding_source_column().get("name")
def _embedding_source_column(self) -> dict:
"""返回嵌入源列配置的字典。
如果索引不是由Databricks管理的嵌入索引,则为空。
"""
index_spec = self._delta_sync_index_spec
return next(iter(index_spec.get("embedding_source_columns") or list()), dict())
def _is_delta_sync_index(self) -> bool:
"""如果索引是增量同步索引,则返回True。"""
return self.index_type == "DELTA_SYNC"
def _is_direct_access_index(self) -> bool:
"""如果索引是直接访问索引,则返回True。"""
return self.index_type == "DIRECT_ACCESS"
def _is_databricks_managed_embeddings(self) -> bool:
"""如果嵌入由Databricks Vector Search 管理,则返回True。"""
return (
self._is_delta_sync_index()
and self._embedding_source_column_name() is not None
)
def _infer_embedding_dimension(self) -> int:
"""从嵌入函数中推断嵌入维度。"""
assert self.embeddings is not None, "embedding model is required."
return len(self.embeddings.embed_query("test"))
def _op_require_direct_access_index(self, op_name: str) -> None:
"""
如果不支持直接访问索引,则引发 ValueError。"""
if not self._is_direct_access_index():
raise ValueError(f"`{op_name}` is only supported for direct-access index.")
@staticmethod
def _require_arg(arg: Any, arg_name: str) -> None:
"""如果名称为`arg_name`的必需参数为None,则引发ValueError。"""
if not arg:
raise ValueError(f"`{arg_name}` is required for this index.")