# Source code for langchain_community.vectorstores.hippo

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

    from transwarp_hippo_api.hippo_client import HippoClient

# Default connection
    "host": "localhost",
    "port": "7788",
    "username": "admin",
    "password": "admin",

logger = logging.getLogger(__name__)

class Hippo(VectorStore):
    """`Hippo` vector store.

    You need to install `hippo-api` and have a Hippo instance running.
    See the official Hippo documentation for how to run an instance.

    Args:
        embedding_function (Embeddings): Object used to embed the texts.
        table_name (str): Hippo table to use. Defaults to "test".
        database_name (str): Hippo database to use. Defaults to "default".
        number_of_shards (int): Number of shards for the Hippo table.
            Defaults to 1.
        number_of_replicas (int): Number of replicas for the Hippo table.
            Defaults to 1.
        connection_args (Optional[dict[str, any]]): Connection arguments,
            given as a dict. When ``None``, ``DEFAULT_HIPPO_CONNECTION`` is
            used.
        index_params (Optional[dict]): Index parameters to use. When ``None``
            an IVF_FLAT / L2 index is created.
        drop_old (Optional[bool]): Whether to drop an existing table of the
            same name. Defaults to False.

    The field names are fixed by the constructor: the primary key field is
    "pk", the text field is "text" and the vector field is "vector".

    Keys read from ``connection_args``:
        host (str): Host of the Hippo instance. Required; several hosts may
            be given as a comma-separated string.
        port (str/int): Port of the Hippo instance. Required.
        username (str): User name for the connection. Falls back to "shiva"
            when the key is absent.
        password (str): Password for the user. Falls back to "shiva" when
            the key is absent.

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import Hippo
            from langchain_community.embeddings import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a hippo instance on localhost
            vector_store = Hippo.from_documents(
                docs,
                embedding=embedding,
                table_name="langchain_test",
                connection_args=HIPPO_CONNECTION
            )

    Raises:
        ImportError: If the hippo-api python package is not installed.
    """

    def __init__(
        self,
        embedding_function: Embeddings,
        table_name: str = "test",
        database_name: str = "default",
        number_of_shards: int = 1,
        number_of_replicas: int = 1,
        connection_args: Optional[Dict[str, Any]] = None,
        index_params: Optional[dict] = None,
        drop_old: Optional[bool] = False,
    ):
        """Connect to Hippo and bind to the table if it already exists.

        When ``drop_old`` is true and the table exists it is deleted first.
        Errors from the existence check, deletion or lookup are logged and
        re-raised.
        """
        self.number_of_shards = number_of_shards
        self.number_of_replicas = number_of_replicas
        self.embedding_func = embedding_function
        self.table_name = table_name
        self.database_name = database_name
        self.index_params = index_params

        # In order for a collection to be compatible,
        # 'pk' should be an auto-increment primary key and string
        self._primary_field = "pk"
        # In order for compatibility, the text field will need to be called "text"
        self._text_field = "text"
        # In order for compatibility, the vector field needs to be called "vector"
        self._vector_field = "vector"
        self.fields: List[str] = []

        # Create the connection to the server
        if connection_args is None:
            connection_args = DEFAULT_HIPPO_CONNECTION
        self.hc = self._create_connection_alias(connection_args)
        self.col: Any = None

        # If the collection exists and drop_old is set, delete it
        try:
            if (
                self.hc.check_table_exists(self.table_name, self.database_name)
                and drop_old
            ):
                self.hc.delete_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while deleting the table "
                f"{self.table_name}: {e}"
            )
            raise

        # Bind to the table when it is (still) present
        try:
            if self.hc.check_table_exists(self.table_name, self.database_name):
                self.col = self.hc.get_table(self.table_name, self.database_name)
        except Exception as e:
            logger.error(
                f"An error occurred while getting the table "
                f"{self.table_name}: {e}"
            )
            raise

        # Initialize the vector database
        self._get_env()

    def _create_connection_alias(self, connection_args: dict) -> HippoClient:
        """Create the connection to the Hippo server.

        Args:
            connection_args: Dict with "host" and "port" (both required) and
                optional "username" / "password".

        Returns:
            A ``HippoClient`` built from the resolved address.

        Raises:
            ImportError: If ``transwarp_hippo_api`` cannot be imported.
            ValueError: If "host" or "port" is missing.
        """
        try:
            from transwarp_hippo_api.hippo_client import HippoClient
        except ImportError as e:
            raise ImportError(
                "Unable to import transwarp_hippo_api, please install with "
                "`pip install hippo-api`."
            ) from e

        host: str = connection_args.get("host", None)
        port: int = connection_args.get("port", None)
        username: str = connection_args.get("username", "shiva")
        password: str = connection_args.get("password", "shiva")

        if host is None or port is None:
            raise ValueError("Missing standard address type for reuse attempt")

        # A comma-separated host list becomes "h1:port,h2:port,..."
        if "," in host:
            given_address = ",".join(f"{h}:{port}" for h in host.split(","))
        else:
            given_address = str(host) + ":" + str(port)

        try:
            logger.info(f"create HippoClient[{given_address}]")
            return HippoClient([given_address], username=username, pwd=password)
        except Exception:
            logger.error("Failed to create new connection")
            raise

    def _get_env(
        self, embeddings: Optional[list] = None, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Prepare the table: create it if embeddings are given, then load
        its field names and make sure the vector index exists."""
        logger.info("init ...")
        if embeddings is not None:
            logger.info("create collection")
            self._create_collection(embeddings, metadatas)
        self._extract_fields()
        self._create_index()

    def _create_collection(
        self, embeddings: list, metadatas: Optional[List[dict]] = None
    ) -> None:
        """Create the Hippo table and bind ``self.col`` to it.

        The vector dimension is taken from the first embedding. Metadata
        columns are derived from the first metadata dict only: list values
        become float-vector columns, everything else becomes a string column.
        """
        from transwarp_hippo_api.hippo_client import HippoField
        from transwarp_hippo_api.hippo_type import HippoType

        # Determine embedding dim
        dim = len(embeddings[0])
        logger.debug(f"[_create_collection] dim: {dim}")

        fields = [
            # Primary key field
            HippoField(self._primary_field, True, HippoType.STRING),
            # Text field
            HippoField(self._text_field, False, HippoType.STRING),
            # Vector field; only float vectors are handled here
            HippoField(
                self._vector_field,
                False,
                HippoType.FLOAT_VECTOR,
                type_params={"dimension": dim},
            ),
        ]

        # There is no type inference for metadata here: every non-list value
        # is stored as a string column.
        if metadatas:
            for key, value in metadatas[0].items():
                if isinstance(value, list):
                    fields.append(
                        HippoField(
                            key,
                            False,
                            HippoType.FLOAT_VECTOR,
                            type_params={"dimension": len(value)},
                        )
                    )
                else:
                    fields.append(HippoField(key, False, HippoType.STRING))

        logger.debug(f"[_create_collection] fields: {fields}")

        # Create the collection
        self.hc.create_table(
            name=self.table_name,
            auto_id=True,
            fields=fields,
            database_name=self.database_name,
            number_of_shards=self.number_of_shards,
            number_of_replicas=self.number_of_replicas,
        )
        self.col = self.hc.get_table(self.table_name, self.database_name)
        logger.info(
            f"[_create_collection] : "
            f"create table {self.table_name} in {self.database_name} successfully"
        )

    def _extract_fields(self) -> None:
        """Grab the existing fields from the collection into ``self.fields``."""
        from transwarp_hippo_api.hippo_client import HippoTable

        if isinstance(self.col, HippoTable):
            schema = self.col.schema
            logger.debug(f"[_extract_fields] schema:{schema}")
            for x in schema:
                # NOTE(review): the appended expression was lost in the scraped
                # source; `x.name` follows the upstream implementation - confirm.
                self.fields.append(x.name)
            logger.debug(f"04 [_extract_fields] fields:{self.fields}")

    # Currently only the field named by self._vector_field (the automatically
    # created vector field) is checked for an index. Indexes for other
    # vector-typed columns need to be created manually.
    def _get_index(self) -> Optional[Dict[str, Any]]:
        """Return the vector index information if it exists, else ``None``."""
        from transwarp_hippo_api.hippo_client import HippoTable

        if not isinstance(self.col, HippoTable):
            return None

        table_info = self.hc.get_table_info(
            self.table_name, self.database_name
        ).get(self.table_name, {})
        embedding_indexes = table_info.get("embedding_indexes", None)
        if embedding_indexes is None:
            return None

        logger.debug(f"[_get_index] embedding_indexes {embedding_indexes}")
        for x in embedding_indexes:
            if x["column"] == self._vector_field:
                return x
        return None

    # Indexes can only be created for the self._vector_field field.
    def _create_index(self) -> None:
        """Create and activate an index on the vector field if none exists.

        With no ``index_params`` an IVF_FLAT / L2 index named
        "langchain_auto_create" (nlist=10) is created. Otherwise the string
        "index_type" and "metric_type" values are mapped to the Hippo enums
        and the type-specific options are forwarded.

        Raises:
            KeyError: If "metric_type" is not one of ip/IP/l2/L2.
            ValueError: If "index_type" is not a supported index type.
        """
        from transwarp_hippo_api.hippo_client import HippoTable
        from transwarp_hippo_api.hippo_type import IndexType, MetricType

        if not isinstance(self.col, HippoTable) or self._get_index() is not None:
            return

        if self.index_params is None:
            self.index_params = {
                "index_name": "langchain_auto_create",
                "metric_type": MetricType.L2,
                "index_type": IndexType.IVF_FLAT,
                "nlist": 10,
            }
            self.col.create_index(
                self._vector_field,
                self.index_params["index_name"],
                self.index_params["index_type"],
                self.index_params["metric_type"],
                nlist=self.index_params["nlist"],
            )
            logger.debug(self.col.activate_index(self.index_params["index_name"]))
            logger.info("create index successfully")
            return

        index_dict = {
            "IVF_FLAT": IndexType.IVF_FLAT,
            "FLAT": IndexType.FLAT,
            "IVF_SQ": IndexType.IVF_SQ,
            "IVF_PQ": IndexType.IVF_PQ,
            "HNSW": IndexType.HNSW,
        }
        metric_dict = {
            "ip": MetricType.IP,
            "IP": MetricType.IP,
            "l2": MetricType.L2,
            "L2": MetricType.L2,
        }

        # Work on a copy so the dict supplied by the caller is never mutated
        # (it may be shared between several Hippo instances).
        params = dict(self.index_params)
        metric_type = metric_dict[params["metric_type"]]
        index_type_name = params["index_type"]

        # Extra keyword arguments depend on the index type.
        if index_type_name == "FLAT":
            extra: Dict[str, Any] = {}
        elif index_type_name in ("IVF_FLAT", "IVF_SQ"):
            extra = {
                "nlist": params.get("nlist", 10),
                "nprobe": params.get("nprobe", 10),
            }
        elif index_type_name == "IVF_PQ":
            extra = {
                "nlist": params.get("nlist", 10),
                "nprobe": params.get("nprobe", 10),
                "nbits": params.get("nbits", 8),
                "m": params.get("m"),
            }
        elif index_type_name == "HNSW":
            extra = {
                "M": params.get("M"),
                "ef_construction": params.get("ef_construction"),
                "ef_search": params.get("ef_search"),
            }
        else:
            raise ValueError(
                "Index name does not match, "
                "please enter the correct index name. "
                "(FLAT, IVF_FLAT, IVF_PQ,IVF_SQ, HNSW)"
            )

        index_type = index_dict[index_type_name]
        # Keep the resolved enum values on the instance, as before.
        params["metric_type"] = metric_type
        params["index_type"] = index_type
        self.index_params = params

        self.col.create_index(
            self._vector_field,
            params["index_name"],
            index_type,
            metric_type,
            **extra,
        )
        logger.debug(self.col.activate_index(params["index_name"]))

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        timeout: Optional[int] = None,
        batch_size: int = 1000,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts to the collection.

        Args:
            texts: Iterable of the texts to add.
            metadatas: Optional list of dicts, one per text, holding the
                metadata associated with it.
            timeout: Optional timeout in seconds (currently unused).
            batch_size: Number of texts inserted per batch. Defaults to 1000.
            **kwargs: Other optional arguments (currently unused).

        Returns:
            ``[]`` when there is nothing to insert, otherwise ``[""]``.

        Note:
            If the collection has not been created yet, this method creates
            it.
        """
        from transwarp_hippo_api.hippo_client import HippoTable

        # Materialize first: `texts` may be a one-shot iterator, and the
        # emptiness check below must not consume it.
        texts = list(texts)
        if not texts or all(t == "" for t in texts):
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] texts: {texts}")

        try:
            embeddings = self.embedding_func.embed_documents(texts)
        except NotImplementedError:
            embeddings = [self.embedding_func.embed_query(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        logger.debug(f"[add_texts] len_embeddings:{len(embeddings)}")

        # Create the collection if it has not been created yet
        if not isinstance(self.col, HippoTable):
            self._get_env(embeddings, metadatas)

        # Dict to hold all insert columns
        insert_dict: Dict[str, list] = {
            self._text_field: texts,
            self._vector_field: embeddings,
        }
        logger.debug(f"[add_texts] metadatas:{metadatas}")
        logger.debug(f"[add_texts] fields:{self.fields}")
        if metadatas is not None:
            for d in metadatas:
                for key, value in d.items():
                    if key in self.fields:
                        insert_dict.setdefault(key, []).append(value)

        logger.debug(insert_dict[self._text_field])

        # Total insert count
        vectors: list = insert_dict[self._vector_field]
        total_count = len(vectors)

        # The table is created with auto_id=True, so no "pk" column is sent.
        if "pk" in self.fields:
            self.fields.remove("pk")

        logger.debug(f"[add_texts] total_count:{total_count}")
        for i in range(0, total_count, batch_size):
            # Grab end index
            end = min(i + batch_size, total_count)
            # Convert dict to list of lists batch for insertion
            insert_list = [insert_dict[x][i:end] for x in self.fields]
            try:
                res = self.col.insert_rows(insert_list)
                logger.info(f"05 [add_texts] insert {res}")
            except Exception:
                logger.error(
                    "Failed to insert batch starting at entity: %s/%s", i, total_count
                )
                raise
        return [""]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Search with a query string and return the results with scores.

        Args:
            query (str): The text being searched.
            k (int, optional): Number of results to return. Defaults to 4.
            param (dict): Search params for the index. Defaults to None.
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before a timeout error.
                Defaults to None.
            kwargs: Forwarded to ``similarity_search_with_score_by_vector``.

        Returns:
            List[Tuple[Document, float]]: Result documents and scores; empty
            when no collection exists.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Embed the query text.
        embedding = self.embedding_func.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
        )

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        param: Optional[dict] = None,
        expr: Optional[str] = None,
        timeout: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Search with an embedding vector and return the results with scores.

        Args:
            embedding (List[float]): The embedding vector being searched.
            k (int, optional): Number of results to return. Defaults to 4.
            param (dict): Search params for the index (currently unused).
            expr (str, optional): Filtering expression. Defaults to None.
            timeout (int, optional): How long to wait before a timeout error
                (currently unused).
            kwargs: Additional keyword arguments (currently unused).

        Returns:
            List[Tuple[Document, float]]: Result documents and scores; empty
            when no collection exists.
        """
        if self.col is None:
            logger.debug("No existing collection to search.")
            return []

        # Result fields are every known field except the vector itself.
        # Filtering (instead of list.remove) does not raise when the vector
        # field is absent from self.fields.
        output_fields = [f for f in self.fields if f != self._vector_field]

        # Perform the search.
        logger.debug(f"search_field:{self._vector_field}")
        logger.debug(f"vectors:{[embedding]}")
        logger.debug(f"output_fields:{output_fields}")
        logger.debug(f"topk:{k}")
        logger.debug(f"dsl:{expr}")

        res = self.col.query(
            search_field=self._vector_field,
            vectors=[embedding],
            output_fields=output_fields,
            topk=k,
            dsl=expr,
        )

        # Organize results.
        logger.debug(f"[similarity_search_with_score_by_vector] res:{res}")
        score_col = self._text_field + "%scores"
        rows = list(zip(*[res[0][field] for field in output_fields]))
        if not rows:
            return []

        scores = res[0][score_col]
        logger.debug(
            f"[similarity_search_with_score_by_vector] "
            f"res[0][score_col]:{scores}"
        )
        ret = []
        for idx, items in enumerate(rows):
            meta = dict(zip(output_fields, items))
            doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
            ret.append((doc, scores[idx]))
        return ret

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table_name: str = "test",
        database_name: str = "default",
        connection_args: Dict[str, Any] = DEFAULT_HIPPO_CONNECTION,
        index_params: Optional[Dict[Any, Any]] = None,
        search_params: Optional[Dict[str, Any]] = None,
        drop_old: bool = False,
        **kwargs: Any,
    ) -> "Hippo":
        """Create a Hippo vector store from the given texts.

        Args:
            texts (List[str]): Texts to add.
            embedding (Embeddings): Embedding model for the texts.
            metadatas (List[dict], optional): One metadata dict per text.
                Defaults to None.
            table_name (str): Name of the table. Defaults to "test".
            database_name (str): Name of the database. Defaults to "default".
            connection_args (dict[str, Any]): Connection arguments. Defaults
                to DEFAULT_HIPPO_CONNECTION.
            index_params (dict): Index parameters. Defaults to None.
            search_params (dict): Search parameters (currently unused).
                Defaults to None.
            drop_old (bool): Whether to drop the old collection. Defaults to
                False.
            kwargs: Other arguments forwarded to the constructor.

        Returns:
            Hippo: The populated vector store instance.
        """
        if search_params is None:
            search_params = {}

        logger.info("00 [from_texts] init the class of Hippo")
        vector_db = cls(
            embedding_function=embedding,
            table_name=table_name,
            database_name=database_name,
            connection_args=connection_args,
            index_params=index_params,
            drop_old=drop_old,
            **kwargs,
        )
        logger.debug(f"[from_texts] texts:{texts}")
        logger.debug(f"[from_texts] metadatas:{metadatas}")
        vector_db.add_texts(texts=texts, metadatas=metadatas)
        return vector_db