Source code for langchain_community.vectorstores.databricks_vector_search

from __future__ import annotations

import json
import logging
import uuid
from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Tuple, Type

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
    from databricks.vector_search.client import VectorSearchIndex

logger = logging.getLogger(__name__)


class DatabricksVectorSearch(VectorStore):
    """`Databricks Vector Search` vector store.

    To use, you should have the ``databricks-vectorsearch`` python package
    installed.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(vs_index)

    Args:
        index: A Databricks Vector Search index object.
        embedding: The embedding model.
            Required for direct-access index or delta-sync index with
            self-managed embeddings.
        text_column: The name of the text column to use for the embeddings.
            Required for direct-access index or delta-sync index with
            self-managed embeddings.
            Make sure the text column specified is in the index.
        columns: The list of column names to get when doing the search.
            Defaults to ``[primary_key, text_column]``.

    A delta-sync index with Databricks-managed embeddings manages the
    ingestion, deletion, and embedding for you. Manual ingestion/deletion of
    documents/texts is not supported for delta-sync indexes.

    If you want to use a delta-sync index with self-managed embeddings, you
    need to provide the embedding model and the name of the text column to
    use for the embeddings.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient
            from langchain_community.embeddings.openai import OpenAIEmbeddings

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(
                index=vs_index,
                embedding=OpenAIEmbeddings(),
                text_column="document_content"
            )

    If you want to manage document ingestion/deletion yourself, you can use a
    direct-access index.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import DatabricksVectorSearch
            from databricks.vector_search.client import VectorSearchClient
            from langchain_community.embeddings.openai import OpenAIEmbeddings

            vs_client = VectorSearchClient()
            vs_index = vs_client.get_index(
                endpoint_name="vs_endpoint",
                index_name="ml.llm.index"
            )
            vectorstore = DatabricksVectorSearch(
                index=vs_index,
                embedding=OpenAIEmbeddings(),
                text_column="document_content"
            )
            vectorstore.add_texts(
                texts=["text1", "text2"]
            )

    For more information on Databricks Vector Search, see the Databricks
    Vector Search documentation:
    https://docs.databricks.com/en/generative-ai/vector-search.html
    """
    def __init__(
        self,
        index: VectorSearchIndex,
        *,
        embedding: Optional[Embeddings] = None,
        text_column: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ):
        try:
            from databricks.vector_search.client import VectorSearchIndex
        except ImportError as e:
            raise ImportError(
                "Could not import databricks-vectorsearch python package. "
                "Please install it with `pip install databricks-vectorsearch`."
            ) from e
        # index
        self.index = index
        if not isinstance(index, VectorSearchIndex):
            raise TypeError("index must be of type VectorSearchIndex.")

        # index_details
        index_details = self.index.describe()

        self.primary_key = index_details["primary_key"]
        self.index_type = index_details.get("index_type")
        self._delta_sync_index_spec = index_details.get(
            "delta_sync_index_spec", dict()
        )
        self._direct_access_index_spec = index_details.get(
            "direct_access_index_spec", dict()
        )

        # text_column
        if self._is_databricks_managed_embeddings():
            index_source_column = self._embedding_source_column_name()
            # check if input text column matches the source column of the index
            if text_column is not None and text_column != index_source_column:
                raise ValueError(
                    f"text_column '{text_column}' does not match with the "
                    f"source column of the index: '{index_source_column}'."
                )
            self.text_column = index_source_column
        else:
            self._require_arg(text_column, "text_column")
            self.text_column = text_column

        # columns
        self.columns = columns or []
        # add primary key column and source column if not in columns
        if self.primary_key not in self.columns:
            self.columns.append(self.primary_key)
        if self.text_column and self.text_column not in self.columns:
            self.columns.append(self.text_column)

        # Validate specified columns are in the index
        if self._is_direct_access_index():
            index_schema = self._index_schema()
            if index_schema:
                for col in self.columns:
                    if col not in index_schema:
                        raise ValueError(
                            f"column '{col}' is not in the index's schema."
                        )

        # embedding model
        if not self._is_databricks_managed_embeddings():
            # embedding model is required for direct-access index
            # or delta-sync index with self-managed embedding
            self._require_arg(embedding, "embedding")
            self._embedding = embedding
            # validate dimension matches
            index_embedding_dimension = self._embedding_vector_column_dimension()
            if index_embedding_dimension is not None:
                inferred_embedding_dimension = self._infer_embedding_dimension()
                if inferred_embedding_dimension != index_embedding_dimension:
                    raise ValueError(
                        f"embedding model's dimension "
                        f"'{inferred_embedding_dimension}' does not match with "
                        f"the index's dimension '{index_embedding_dimension}'."
                    )
        else:
            if embedding is not None:
                logger.warning(
                    "embedding model is not used in delta-sync index with "
                    "Databricks-managed embeddings."
                )
            self._embedding = None
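    # Illustrative sketch (not from the source): the constructor above relies
    # on the following fields of ``index.describe()``. All values shown here
    # are made up; the real response contains additional fields.
    #
    #     index_details = {
    #         "primary_key": "id",
    #         "index_type": "DIRECT_ACCESS",  # or "DELTA_SYNC"
    #         "direct_access_index_spec": {
    #             "embedding_vector_columns": [
    #                 {"name": "text_vector", "embedding_dimension": 1536}
    #             ],
    #             "schema_json": '{"id": "string", "document_content": "string"}',
    #         },
    #     }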
    @classmethod
    def from_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> VST:
        raise NotImplementedError(
            "`from_texts` is not supported. "
            "Use `add_texts` to add to existing direct-access index."
        )
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts to the index.

        Only supported for direct-access index.

        Args:
            texts: List of texts to add.
            metadatas: List of metadata for each text. Defaults to None.
            ids: List of ids for each text. Defaults to None.
                If not provided, a random uuid will be generated for each text.

        Returns:
            List of ids from adding the texts into the index.
        """
        self._op_require_direct_access_index("add_texts")
        assert self.embeddings is not None, "embedding model is required."
        # Wrap to list if input texts is a single string
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        vectors = self.embeddings.embed_documents(texts)
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        metadatas = metadatas or [{} for _ in texts]

        updates = [
            {
                self.primary_key: id_,
                self.text_column: text,
                self._embedding_vector_column_name(): vector,
                **metadata,
            }
            for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
        ]

        upsert_resp = self.index.upsert(updates)
        if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
            failed_ids = upsert_resp.get("result", dict()).get(
                "failed_primary_keys", []
            )
            if upsert_resp.get("status") == "FAILURE":
                logger.error("Failed to add texts to the index.")
            else:
                logger.warning("Some texts failed to be added to the index.")
            return [id_ for id_ in ids if id_ not in failed_ids]

        return ids
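    # Usage sketch (assumes a direct-access index configured as in the class
    # docstring; the ids and metadata values below are illustrative):
    #
    #     ids = vectorstore.add_texts(
    #         texts=["text1", "text2"],
    #         metadatas=[{"source": "a.txt"}, {"source": "b.txt"}],
    #         ids=["doc-1", "doc-2"],  # omit to auto-generate random uuids
    #     )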
    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Access the query embedding object if available."""
        return self._embedding
    def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents from the index.

        Only supported for direct-access index.

        Args:
            ids: List of ids of documents to delete.

        Returns:
            True if successful.
        """
        self._op_require_direct_access_index("delete")
        if ids is None:
            raise ValueError("ids must be provided.")
        self.index.delete(ids)
        return True
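    # Usage sketch (direct-access index only; the ids are whatever primary-key
    # values were used at ingestion time):
    #
    #     vectorstore.delete(ids=["doc-1", "doc-2"])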
    def similarity_search_with_score(
        self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, along with scores.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._is_databricks_managed_embeddings():
            query_text = query
            query_vector = None
        else:
            assert self.embeddings is not None, "embedding model is required."
            query_text = None
            query_vector = self.embeddings.embed_query(query)

        search_resp = self.index.similarity_search(
            columns=self.columns,
            query_text=query_text,
            query_vector=query_vector,
            filters=filters,
            num_results=k,
        )
        return self._parse_search_response(search_resp)
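    # Usage sketch; the ``filters`` value is hypothetical -- see the
    # Databricks Vector Search documentation for the supported filter syntax:
    #
    #     docs_with_scores = vectorstore.similarity_search_with_score(
    #         "what is databricks vector search?",
    #         k=4,
    #         filters={"source": "a.txt"},
    #     )
    #     for doc, score in docs_with_scores:
    #         print(score, doc.page_content)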
    @staticmethod
    def _identity_fn(score: float) -> float:
        return score

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """Databricks Vector Search returns a normalized score 1/(1 + d),
        where d is the L2 distance, so we simply return the identity function.
        """
        return self._identity_fn
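    # Worked example: with the normalized score 1/(1 + d), an exact match
    # (L2 distance d = 0) scores 1/(1 + 0) = 1.0 and d = 1 scores 0.5, so the
    # scores already fall in (0, 1] and need no further transformation.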
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree of
                diversity among the results, with 0 corresponding to maximum
                diversity and 1 to minimum diversity. Defaults to 0.5.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if not self._is_databricks_managed_embeddings():
            embedding_column = self._embedding_vector_column_name()
        else:
            raise ValueError(
                "`max_marginal_relevance_search` is not supported for index with "
                "Databricks-managed embeddings."
            )
        search_resp = self.index.similarity_search(
            columns=list(set(self.columns + [embedding_column])),
            query_text=None,
            query_vector=embedding,
            filters=filters,
            num_results=fetch_k,
        )

        embeddings_result_index = (
            search_resp.get("manifest").get("columns").index({"name": embedding_column})
        )
        embeddings = [
            doc[embeddings_result_index]
            for doc in search_resp.get("result").get("data_array")
        ]
        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )

        ignore_cols: List = (
            [embedding_column] if embedding_column not in self.columns else []
        )
        candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
        selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results
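    # Usage sketch (requires self-managed embeddings, since the method raises
    # for Databricks-managed embeddings; the query string is illustrative):
    #
    #     query_vector = vectorstore.embeddings.embed_query("what is mmr?")
    #     docs = vectorstore.max_marginal_relevance_search_by_vector(
    #         embedding=query_vector,
    #         k=4,
    #         fetch_k=20,
    #         lambda_mult=0.5,  # 0 = max diversity, 1 = min diversity
    #     )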
    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_with_score = self.similarity_search_by_vector_with_score(
            embedding=embedding, k=k, filters=filters, **kwargs
        )
        return [doc for doc, _ in docs_with_score]
    def similarity_search_by_vector_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        filters: Optional[Any] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector, along with scores.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filters: Filters to apply to the query. Defaults to None.

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._is_databricks_managed_embeddings():
            raise ValueError(
                "`similarity_search_by_vector` is not supported for index with "
                "Databricks-managed embeddings."
            )
        search_resp = self.index.similarity_search(
            columns=self.columns,
            query_vector=embedding,
            filters=filters,
            num_results=k,
        )
        return self._parse_search_response(search_resp)
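    # Usage sketch (self-managed embeddings only; mirrors
    # ``similarity_search_with_score`` but accepts a precomputed vector):
    #
    #     query_vector = vectorstore.embeddings.embed_query("some query")
    #     docs_with_scores = vectorstore.similarity_search_by_vector_with_score(
    #         embedding=query_vector, k=4
    #     )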
    def _parse_search_response(
        self, search_resp: dict, ignore_cols: Optional[List[str]] = None
    ) -> List[Tuple[Document, float]]:
        """Parse the search response into a list of Documents with score."""
        if ignore_cols is None:
            ignore_cols = []

        columns = [
            col["name"]
            for col in search_resp.get("manifest", dict()).get("columns", [])
        ]
        docs_with_score = []
        for result in search_resp.get("result", dict()).get("data_array", []):
            doc_id = result[columns.index(self.primary_key)]
            text_content = result[columns.index(self.text_column)]
            metadata = {
                col: value
                for col, value in zip(columns[:-1], result[:-1])
                if col not in ([self.primary_key, self.text_column] + ignore_cols)
            }
            metadata[self.primary_key] = doc_id
            score = result[-1]
            doc = Document(page_content=text_content, metadata=metadata)
            docs_with_score.append((doc, score))
        return docs_with_score

    def _index_schema(self) -> Optional[dict]:
        """Return the index schema as a dictionary.
        Return None if no schema is found.
        """
        if self._is_direct_access_index():
            schema_json = self._direct_access_index_spec.get("schema_json")
            if schema_json is not None:
                return json.loads(schema_json)
        return None

    def _embedding_vector_column_name(self) -> Optional[str]:
        """Return the name of the embedding vector column.
        None if the index is not a self-managed embedding index.
        """
        return self._embedding_vector_column().get("name")

    def _embedding_vector_column_dimension(self) -> Optional[int]:
        """Return the dimension of the embedding vector column.
        None if the index is not a self-managed embedding index.
        """
        return self._embedding_vector_column().get("embedding_dimension")

    def _embedding_vector_column(self) -> dict:
        """Return the embedding vector column configs as a dictionary.
        Empty if the index is not a self-managed embedding index.
        """
        index_spec = (
            self._delta_sync_index_spec
            if self._is_delta_sync_index()
            else self._direct_access_index_spec
        )
        return next(iter(index_spec.get("embedding_vector_columns") or list()), dict())

    def _embedding_source_column_name(self) -> Optional[str]:
        """Return the name of the embedding source column.
        None if the index is not a Databricks-managed embedding index.
        """
        return self._embedding_source_column().get("name")

    def _embedding_source_column(self) -> dict:
        """Return the embedding source column configs as a dictionary.
        Empty if the index is not a Databricks-managed embedding index.
        """
        index_spec = self._delta_sync_index_spec
        return next(iter(index_spec.get("embedding_source_columns") or list()), dict())

    def _is_delta_sync_index(self) -> bool:
        """Return True if the index is a delta-sync index."""
        return self.index_type == "DELTA_SYNC"

    def _is_direct_access_index(self) -> bool:
        """Return True if the index is a direct-access index."""
        return self.index_type == "DIRECT_ACCESS"

    def _is_databricks_managed_embeddings(self) -> bool:
        """Return True if the embeddings are managed by Databricks Vector Search."""
        return (
            self._is_delta_sync_index()
            and self._embedding_source_column_name() is not None
        )

    def _infer_embedding_dimension(self) -> int:
        """Infer the embedding dimension from the embedding function."""
        assert self.embeddings is not None, "embedding model is required."
        return len(self.embeddings.embed_query("test"))

    def _op_require_direct_access_index(self, op_name: str) -> None:
        """Raise ValueError if the operation is only supported for
        direct-access index."""
        if not self._is_direct_access_index():
            raise ValueError(f"`{op_name}` is only supported for direct-access index.")

    @staticmethod
    def _require_arg(arg: Any, arg_name: str) -> None:
        """Raise ValueError if the required arg with name `arg_name` is None."""
        if not arg:
            raise ValueError(f"`{arg_name}` is required for this index.")
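# Illustrative shape (not from the source) of the response consumed by
# ``_parse_search_response`` above: ``manifest.columns`` lists the returned
# columns in order, ``result.data_array`` holds one row per hit, and the last
# element of each row is the similarity score. Column names and values below
# are made up for the sketch.
#
#     search_resp = {
#         "manifest": {
#             "columns": [
#                 {"name": "id"},
#                 {"name": "document_content"},
#                 {"name": "score"},
#             ]
#         },
#         "result": {
#             "data_array": [
#                 ["doc-1", "text1", 0.83],
#                 ["doc-2", "text2", 0.71],
#             ]
#         },
#     }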