langchain_community.vectorstores.tencentvectordb ηš„ζΊδ»£η 

"""Wrapper around the Tencent vector database."""

from __future__ import annotations

import json
import logging
import time
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from pydantic import BaseModel

from langchain_community.vectorstores.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


META_FIELD_TYPE_UINT64 = "uint64"
META_FIELD_TYPE_STRING = "string"
META_FIELD_TYPE_ARRAY = "array"
META_FIELD_TYPE_VECTOR = "vector"

META_FIELD_TYPES = [
    META_FIELD_TYPE_UINT64,
    META_FIELD_TYPE_STRING,
    META_FIELD_TYPE_ARRAY,
    META_FIELD_TYPE_VECTOR,
]


[docs] class ConnectionParams: """Tencent vector DB Connection params. See the following documentation for details: https://cloud.tencent.com/document/product/1709/95820 Attribute: url (str) : The access address of the vector database server that the client needs to connect to. key (str): API key for client to access the vector database server, which is used for authentication. username (str) : Account for client to access the vector database server. timeout (int) : Request Timeout. """
[docs] def __init__(self, url: str, key: str, username: str = "root", timeout: int = 10): self.url = url self.key = key self.username = username self.timeout = timeout
[docs] class IndexParams: """Tencent vector DB Index params. See the following documentation for details: https://cloud.tencent.com/document/product/1709/95826 """
[docs] def __init__( self, dimension: int, shard: int = 1, replicas: int = 2, index_type: str = "HNSW", metric_type: str = "L2", params: Optional[Dict] = None, ): self.dimension = dimension self.shard = shard self.replicas = replicas self.index_type = index_type self.metric_type = metric_type self.params = params
[docs] class MetaField(BaseModel): """MetaData Field for Tencent vector DB.""" name: str description: Optional[str] data_type: Union[str, Enum] index: bool = False def __init__(self, **data: Any) -> None: super().__init__(**data) enum = guard_import("tcvectordb.model.enum") if isinstance(self.data_type, str): if self.data_type not in META_FIELD_TYPES: raise ValueError(f"unsupported data_type {self.data_type}") target = [ fe for fe in enum.FieldType if fe.value.lower() == self.data_type.lower() ] if target: self.data_type = target[0] else: raise ValueError(f"unsupported data_type {self.data_type}") else: if self.data_type not in enum.FieldType: raise ValueError(f"unsupported data_type {self.data_type}")
[docs] def translate_filter( lc_filter: str, allowed_fields: Optional[Sequence[str]] = None ) -> str: """Translate LangChain filter to Tencent VectorDB filter. Args: lc_filter (str): LangChain filter. allowed_fields (Optional[Sequence[str]]): Allowed fields for filter. Returns: str: Translated filter. """ from langchain.chains.query_constructor.base import fix_filter_directive from langchain.chains.query_constructor.parser import get_parser from langchain.retrievers.self_query.tencentvectordb import ( TencentVectorDBTranslator, ) from langchain_core.structured_query import FilterDirective tvdb_visitor = TencentVectorDBTranslator(allowed_fields) flt = cast( Optional[FilterDirective], get_parser( allowed_comparators=tvdb_visitor.allowed_comparators, allowed_operators=tvdb_visitor.allowed_operators, allowed_attributes=allowed_fields, ).parse(lc_filter), ) flt = fix_filter_directive(flt) return flt.accept(tvdb_visitor) if flt else ""
[docs] class TencentVectorDB(VectorStore): """Tencent VectorDB as a vector store. In order to use this you need to have a database instance. See the following documentation for details: https://cloud.tencent.com/document/product/1709/104489 """ field_id: str = "id" field_vector: str = "vector" field_text: str = "text" field_metadata: str = "metadata"
[docs] def __init__( self, embedding: Embeddings, connection_params: ConnectionParams, index_params: IndexParams = IndexParams(768), database_name: str = "LangChainDatabase", collection_name: str = "LangChainCollection", drop_old: Optional[bool] = False, collection_description: Optional[str] = "Collection for LangChain", meta_fields: Optional[List[MetaField]] = None, t_vdb_embedding: Optional[str] = "bge-base-zh", ): self.document = guard_import("tcvectordb.model.document") tcvectordb = guard_import("tcvectordb") tcollection = guard_import("tcvectordb.model.collection") enum = guard_import("tcvectordb.model.enum") self.embedding_model = None if embedding is None and t_vdb_embedding: embedding_model = [ model for model in enum.EmbeddingModel if t_vdb_embedding == model.model_name ] if not any(embedding_model): raise ValueError( f"embedding model `{t_vdb_embedding}` is invalid. " f"choices: {[member.model_name for member in enum.EmbeddingModel]}" ) self.embedding_model = tcollection.Embedding( vector_field="vector", field="text", model=embedding_model[0] ) self.embedding_func = embedding self.index_params = index_params self.collection_description = collection_description self.vdb_client = tcvectordb.VectorDBClient( url=connection_params.url, username=connection_params.username, key=connection_params.key, timeout=connection_params.timeout, ) self.meta_fields = meta_fields db_list = self.vdb_client.list_databases() db_exist: bool = False for db in db_list: if database_name == db.database_name: db_exist = True break if db_exist: self.database = self.vdb_client.database(database_name) else: self.database = self.vdb_client.create_database(database_name) try: self.collection = self.database.describe_collection(collection_name) if drop_old: self.database.drop_collection(collection_name) self._create_collection(collection_name) except tcvectordb.exceptions.VectorDBException: self._create_collection(collection_name)
def _create_collection(self, collection_name: str) -> None: enum = guard_import("tcvectordb.model.enum") vdb_index = guard_import("tcvectordb.model.index") index_type = enum.IndexType.__members__.get(self.index_params.index_type) if index_type is None: raise ValueError("unsupported index_type") metric_type = enum.MetricType.__members__.get(self.index_params.metric_type) if metric_type is None: raise ValueError("unsupported metric_type") params = vdb_index.HNSWParams( m=(self.index_params.params or {}).get("M", 16), efconstruction=(self.index_params.params or {}).get("efConstruction", 200), ) index = vdb_index.Index( vdb_index.FilterIndex( self.field_id, enum.FieldType.String, enum.IndexType.PRIMARY_KEY ), vdb_index.VectorIndex( self.field_vector, self.index_params.dimension, index_type, metric_type, params, ), vdb_index.FilterIndex( self.field_text, enum.FieldType.String, enum.IndexType.FILTER ), ) # Add metadata indexes if self.meta_fields is not None: index_meta_fields = [field for field in self.meta_fields if field.index] for field in index_meta_fields: ft_index = vdb_index.FilterIndex( field.name, field.data_type, enum.IndexType.FILTER ) index.add(ft_index) else: index.add( vdb_index.FilterIndex( self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER ) ) self.collection = self.database.create_collection( name=collection_name, shard=self.index_params.shard, replicas=self.index_params.replicas, description=self.collection_description, index=index, embedding=self.embedding_model, ) @property def embeddings(self) -> Embeddings: return self.embedding_func
[docs] def delete( self, ids: Optional[List[str]] = None, filter_expr: Optional[str] = None, **kwargs: Any, ) -> Optional[bool]: """Delete documents from the collection.""" delete_attrs = {} if ids: delete_attrs["ids"] = ids if filter_expr: delete_attrs["filter"] = self.document.Filter(filter_expr) self.collection.delete(**delete_attrs) return True
[docs] @classmethod def from_texts( cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, connection_params: Optional[ConnectionParams] = None, index_params: Optional[IndexParams] = None, database_name: str = "LangChainDatabase", collection_name: str = "LangChainCollection", drop_old: Optional[bool] = False, collection_description: Optional[str] = "Collection for LangChain", meta_fields: Optional[List[MetaField]] = None, t_vdb_embedding: Optional[str] = "bge-base-zh", **kwargs: Any, ) -> TencentVectorDB: """Create a collection, indexes it with HNSW, and insert data.""" if len(texts) == 0: raise ValueError("texts is empty") if connection_params is None: raise ValueError("connection_params is empty") enum = guard_import("tcvectordb.model.enum") if embedding is None and t_vdb_embedding is None: raise ValueError("embedding and t_vdb_embedding cannot be both None") if embedding: embeddings = embedding.embed_documents(texts[0:1]) dimension = len(embeddings[0]) else: embedding_model = [ model for model in enum.EmbeddingModel if t_vdb_embedding == model.model_name ] if not any(embedding_model): raise ValueError( f"embedding model `{t_vdb_embedding}` is invalid. " f"choices: {[member.model_name for member in enum.EmbeddingModel]}" ) dimension = embedding_model[0]._EmbeddingModel__dimensions if index_params is None: index_params = IndexParams(dimension=dimension) else: index_params.dimension = dimension vector_db = cls( embedding=embedding, connection_params=connection_params, index_params=index_params, database_name=database_name, collection_name=collection_name, drop_old=drop_old, collection_description=collection_description, meta_fields=meta_fields, t_vdb_embedding=t_vdb_embedding, ) vector_db.add_texts(texts=texts, metadatas=metadatas) return vector_db
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, timeout: Optional[int] = None, batch_size: int = 1000, ids: Optional[List[str]] = None, **kwargs: Any, ) -> List[str]: """Insert text data into TencentVectorDB.""" texts = list(texts) if len(texts) == 0: logger.debug("Nothing to insert, skipping.") return [] if self.embedding_func: embeddings = self.embedding_func.embed_documents(texts) else: embeddings = [] pks: list[str] = [] total_count = len(texts) for start in range(0, total_count, batch_size): # Grab end index docs = [] end = min(start + batch_size, total_count) for id in range(start, end, 1): metadata = ( self._get_meta(metadatas[id]) if metadatas and metadatas[id] else {} ) doc_id = ids[id] if ids else None doc_attrs: Dict[str, Any] = { "id": doc_id or "{}-{}-{}".format(time.time_ns(), hash(texts[id]), id) } if embeddings: doc_attrs["vector"] = embeddings[id] doc_attrs["text"] = texts[id] doc_attrs.update(metadata) doc = self.document.Document(**doc_attrs) docs.append(doc) pks.append(doc_attrs["id"]) self.collection.upsert(docs, timeout) return pks
[docs] def similarity_search_with_score( self, query: str, k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Perform a search on a query string and return results with score.""" # Embed the query text. if self.embedding_func: embedding = self.embedding_func.embed_query(query) return self.similarity_search_with_score_by_vector( embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs, ) return self.similarity_search_with_score_by_vector( embedding=[], k=k, param=param, expr=expr, timeout=timeout, query=query, **kwargs, )
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> List[Document]: """Perform a similarity search against the query string.""" docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs ) return [doc for doc, _ in docs]
[docs] def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, filter: Optional[str] = None, timeout: Optional[int] = None, query: Optional[str] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Perform a search on a query string and return results with score.""" if filter and not expr: expr = translate_filter( filter, [f.name for f in (self.meta_fields or []) if f.index] ) search_args = { "filter": self.document.Filter(expr) if expr else None, "params": self.document.HNSWSearchParams(ef=(param or {}).get("ef", 10)), "retrieve_vector": False, "limit": k, "timeout": timeout, } if query: search_args["embeddingItems"] = [query] res: List[List[Dict]] = self.collection.searchByText(**search_args).get( "documents" ) else: search_args["vectors"] = [embedding] res = self.collection.search(**search_args) ret: List[Tuple[Document, float]] = [] if res is None or len(res) == 0: return ret for result in res[0]: meta = self._get_meta(result) doc = Document(page_content=result.get(self.field_text), metadata=meta) # type: ignore[arg-type] pair = (doc, result.get("score", 0.0)) ret.append(pair) return ret
def _get_meta(self, result: Dict) -> Dict: """Get metadata from the result.""" if self.meta_fields: return {field.name: result.get(field.name) for field in self.meta_fields} elif result.get(self.field_metadata): raw_meta = result.get(self.field_metadata) if raw_meta and isinstance(raw_meta, str): return json.loads(raw_meta) return {}
[docs] def max_marginal_relevance_search_by_vector( self, embedding: list[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, param: Optional[dict] = None, expr: Optional[str] = None, filter: Optional[str] = None, timeout: Optional[int] = None, **kwargs: Any, ) -> List[Document]: """Perform a search and return results that are reordered by MMR.""" if filter and not expr: expr = translate_filter( filter, [f.name for f in (self.meta_fields or []) if f.index] ) res: List[List[Dict]] = self.collection.search( vectors=[embedding], filter=self.document.Filter(expr) if expr else None, params=self.document.HNSWSearchParams(ef=(param or {}).get("ef", 10)), retrieve_vector=True, limit=fetch_k, timeout=timeout, ) # Organize results. documents = [] ordered_result_embeddings = [] for result in res[0]: meta = self._get_meta(result) doc = Document(page_content=result.get(self.field_text), metadata=meta) # type: ignore[arg-type] documents.append(doc) ordered_result_embeddings.append(result.get(self.field_vector)) # Get the new order of results. new_ordering = maximal_marginal_relevance( np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult ) # Reorder the values and return. return [documents[x] for x in new_ordering if x != -1]
def _select_relevance_score_fn(self) -> Callable[[float], float]: metric_type = self.index_params.metric_type if metric_type == "COSINE": return self._cosine_relevance_score_fn elif metric_type == "L2": return self._euclidean_relevance_score_fn elif metric_type == "IP": return self._max_inner_product_relevance_score_fn else: raise ValueError( "No supported normalization function" f" for distance metric of type: {metric_type}." )