Source code for langchain_community.vectorstores.myscale

from __future__ import annotations

import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseSettings
from langchain_core.vectorstores import VectorStore

logger = logging.getLogger()


[docs]def has_mul_sub_str(s: str, *args: Any) -> bool: """检查字符串是否包含多个子字符串。 参数: s:要检查的字符串。 *args:要检查的子字符串。 返回: 如果所有子字符串都在字符串中,则返回True,否则返回False。 """ for a in args: if a not in s: return False return True
[docs]class MyScaleSettings(BaseSettings): """MyScale客户端配置。 属性: myscale_host (str) : 用于连接到MyScale后端的URL。 默认为'localhost'。 myscale_port (int) : 用于通过HTTP连接的URL端口。默认为8443。 username (str) : 登录的用户名。默认为None。 password (str) : 登录的密码。默认为None。 index_type (str): 索引类型字符串。 index_param (dict): 索引构建参数。 database (str) : 要查找表的数据库名称。默认为'default'。 table (str) : 要操作的表名称。 默认为'vector_table'。 metric (str) : 计算距离的度量标准, 支持的有('L2', 'Cosine', 'IP')。默认为'Cosine'。 column_map (Dict) : 列类型映射,将列名投影到langchain语义上。 必须具有键:`text`,`id`,`vector`, 必须与列数相同。例如: .. code-block:: python { 'id': 'text_id', 'vector': 'text_embedding', 'text': 'text_plain', 'metadata': 'metadata_dictionary_in_json', } 默认为身份映射。""" host: str = "localhost" port: int = 8443 username: Optional[str] = None password: Optional[str] = None index_type: str = "MSTG" index_param: Optional[Dict[str, str]] = None column_map: Dict[str, str] = { "id": "id", "text": "text", "vector": "vector", "metadata": "metadata", } database: str = "default" table: str = "langchain" metric: str = "Cosine" def __getitem__(self, item: str) -> Any: return getattr(self, item) class Config: env_file = ".env" env_prefix = "myscale_" env_file_encoding = "utf-8"
[docs]class MyScale(VectorStore): """`MyScale`向量存储。 您需要一个`clickhouse-connect`的Python包,以及一个有效的账户 来连接到MyScale。 MyScale不仅可以使用简单的向量索引进行搜索。 它还支持具有多个条件、约束甚至子查询的复杂查询。 欲了解更多信息,请访问 [MyScale官方网站](https://docs.myscale.com/en/overview/)"""
[docs] def __init__( self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, **kwargs: Any, ) -> None: """MyScale包装器到LangChain 嵌入(Embeddings): 配置(MyScaleSettings):MyScale客户端的配置 其他关键字参数将传递到 [clickhouse-connect](https://docs.myscale.com/) """ try: from clickhouse_connect import get_client except ImportError: raise ImportError( "Could not import clickhouse connect python package. " "Please install it with `pip install clickhouse-connect`." ) try: from tqdm import tqdm self.pgbar = tqdm except ImportError: # Just in case if tqdm is not installed self.pgbar = lambda x: x super().__init__() if config is not None: self.config = config else: self.config = MyScaleSettings() assert self.config assert self.config.host and self.config.port assert ( self.config.column_map and self.config.database and self.config.table and self.config.metric ) for k in ["id", "vector", "text", "metadata"]: assert k in self.config.column_map assert self.config.metric.upper() in ["IP", "COSINE", "L2"] if self.config.metric in ["ip", "cosine", "l2"]: logger.warning( "Lower case metric types will be deprecated " "the future. Please use one of ('IP', 'Cosine', 'L2')" ) # initialize the schema dim = len(embedding.embed_query("try this out")) index_params = ( ", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_param.items()]) if self.config.index_param else "" ) schema_ = f""" CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( {self.config.column_map['id']} String, {self.config.column_map['text']} String, {self.config.column_map['vector']} Array(Float32), {self.config.column_map['metadata']} JSON, CONSTRAINT cons_vec_len CHECK length(\ {self.config.column_map['vector']}) = {dim}, VECTOR INDEX vidx {self.config.column_map['vector']} \ TYPE {self.config.index_type}(\ 'metric_type={self.config.metric}'{index_params}) ) ENGINE = MergeTree ORDER BY {self.config.column_map['id']} """ self.dim = dim self.BS = "\\" self.must_escape = ("\\", "'") self._embeddings = embedding self.dist_order = ( "ASC" if self.config.metric.upper() in ["COSINE", "L2"] else "DESC" ) # Create a connection to myscale self.client = get_client( host=self.config.host, port=self.config.port, username=self.config.username, password=self.config.password, **kwargs, ) self.client.command("SET allow_experimental_object_type=1") self.client.command(schema_)
@property def embeddings(self) -> Embeddings: return self._embeddings
[docs] def escape_str(self, value: str) -> str: return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
def _build_istr(self, transac: Iterable, column_names: Iterable[str]) -> str: ks = ",".join(column_names) _data = [] for n in transac: n = ",".join([f"'{self.escape_str(str(_n))}'" for _n in n]) _data.append(f"({n})") i_str = f""" INSERT INTO TABLE {self.config.database}.{self.config.table}({ks}) VALUES {','.join(_data)} """ return i_str def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None: _i_str = self._build_istr(transac, column_names) self.client.command(_i_str)
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, batch_size: int = 32, ids: Optional[Iterable[str]] = None, **kwargs: Any, ) -> List[str]: """运行更多文本通过嵌入并添加到向量存储。 参数: texts:要添加到向量存储的字符串的可迭代对象。 ids:可选的与文本关联的id列表。 batch_size:插入的批量大小 metadata:要插入的可选列数据 返回: 将文本添加到向量存储后的id列表。 """ # Embed and create the documents ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts] colmap_ = self.config.column_map transac = [] column_names = { colmap_["id"]: ids, colmap_["text"]: texts, colmap_["vector"]: map(self._embeddings.embed_query, texts), } metadatas = metadatas or [{} for _ in texts] column_names[colmap_["metadata"]] = map(json.dumps, metadatas) assert len(set(colmap_) - set(column_names)) >= 0 keys, values = zip(*column_names.items()) try: t = None for v in self.pgbar( zip(*values), desc="Inserting data...", total=len(metadatas) ): assert len(v[keys.index(self.config.column_map["vector"])]) == self.dim transac.append(v) if len(transac) == batch_size: if t: t.join() t = Thread(target=self._insert, args=[transac, keys]) t.start() transac = [] if len(transac) > 0: if t: t.join() self._insert(transac, keys) return [i for i in ids] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return []
[docs] @classmethod def from_texts( cls, texts: Iterable[str], embedding: Embeddings, metadatas: Optional[List[Dict[Any, Any]]] = None, config: Optional[MyScaleSettings] = None, text_ids: Optional[Iterable[str]] = None, batch_size: int = 32, **kwargs: Any, ) -> MyScale: """创建一个使用现有文本的Myscale包装器 参数: texts (Iterable[str]): 要添加的字符串列表或元组 embedding (Embeddings): 用于提取文本嵌入的函数 config (MyScaleSettings, Optional): Myscale配置 text_ids (Optional[Iterable], optional): 文本的ID。默认为None。 batch_size (int, optional): 传输数据到Myscale时的批处理大小。默认为32。 metadata (List[dict], optional): 文本的元数据。默认为None。 其他关键字参数将传递到 [clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api) 返回: MyScale索引 """ ctx = cls(embedding, config, **kwargs) ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas) return ctx
def __repr__(self) -> str: """文本表示为myscale,打印后端、用户名和模式。 # 通过`str(Myscale())`很容易使用 返回: # repr: 显示连接信息和数据模式的字符串 """ _repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ " _repr += f"{self.config.host}:{self.config.port}\033[0m\n\n" _repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n" _repr += "-" * 51 + "\n" for r in self.client.query( f"DESC {self.config.database}.{self.config.table}" ).named_results(): _repr += ( f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n" ) _repr += "-" * 51 + "\n" return _repr def _build_qstr( self, q_emb: List[float], topk: int, where_str: Optional[str] = None ) -> str: q_emb_str = ",".join(map(str, q_emb)) if where_str: where_str = f"PREWHERE {where_str}" else: where_str = "" q_str = f""" SELECT {self.config.column_map['text']}, {self.config.column_map['metadata']}, dist FROM {self.config.database}.{self.config.table} {where_str} ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}]) AS dist {self.dist_order} LIMIT {topk} """ return q_str
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """使用MyScale通过向量执行相似性搜索 参数: query (str): 查询字符串 k (int, optional): 要检索的前K个邻居。默认为4。 where_str (Optional[str], optional): where条件字符串。 默认为None。 注意: 请不要让最终用户填写这个内容,并始终注意SQL注入问题。 处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。 其默认名称为`metadata`。 返回: List[Document]: (Document, 相似度)的列表 """ q_str = self._build_qstr(embedding, k, where_str) try: return [ Document( page_content=r[self.config.column_map["text"]], metadata=r[self.config.column_map["metadata"]], ) for r in self.client.query(q_str).named_results() ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return []
[docs] def similarity_search_with_relevance_scores( self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any ) -> List[Tuple[Document, float]]: """使用MyScale进行相似性搜索 参数: query (str): 查询字符串 k (int, optional): 要检索的前K个最近邻居。默认为4。 where_str (Optional[str], optional): where条件字符串。 默认为None。 注意: 请不要让最终用户填写此内容,并始终注意SQL注入的问题。 处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。 其默认名称为`metadata`。 返回: List[Document]: 与查询文本最相似的文档列表,以及每个文档的余弦距离。 较低的分数表示更相似。 """ q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str) try: return [ ( Document( page_content=r[self.config.column_map["text"]], metadata=r[self.config.column_map["metadata"]], ), r["dist"], ) for r in self.client.query(q_str).named_results() ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return []
[docs] def drop(self) -> None: """ 辅助函数:丢弃数据 """ self.client.command( f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}" )
[docs] def delete( self, ids: Optional[List[str]] = None, where_str: Optional[str] = None, **kwargs: Any, ) -> Optional[bool]: """根据向量ID或其他条件删除。 参数: ids:要删除的ID列表。 **kwargs:子类可能使用的其他关键字参数。 返回: Optional[bool]:如果删除成功则为True,否则为False,如果未实现则为None。 """ assert not ( ids is None and where_str is None ), "You need to specify where to be deleted! Either with `ids` or `where_str`" conds = [] if ids and len(ids) > 0: id_list = ", ".join([f"'{id}'" for id in ids]) conds.append(f"{self.config.column_map['id']} IN ({id_list})") if where_str: conds.append(where_str) assert len(conds) > 0 where_str_final = " AND ".join(conds) qstr = ( f"DELETE FROM {self.config.database}.{self.config.table} " f"WHERE {where_str_final}" ) try: self.client.command(qstr) return True except Exception as e: logger.error(str(e)) return False
@property def metadata_column(self) -> str: return self.config.column_map["metadata"]
[docs]class MyScaleWithoutJSON(MyScale): """我的规模向量存储没有元数据列 如果您正在处理一个SQL本地表,这将非常方便"""
[docs] def __init__( self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any, ) -> None: """构建一个没有元数据列的myscale向量存储 embedding (嵌入): 嵌入模型 config (MyScaleSettings): MyScale客户端的配置 must_have_cols (List[str]): 查询中要包含的列名 其他关键字参数将传递到 [clickhouse-connect](https://docs.myscale.com/) """ super().__init__(embedding, config, **kwargs) self.must_have_cols: List[str] = must_have_cols
def _build_qstr( self, q_emb: List[float], topk: int, where_str: Optional[str] = None ) -> str: q_emb_str = ",".join(map(str, q_emb)) if where_str: where_str = f"PREWHERE {where_str}" else: where_str = "" q_str = f""" SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)} FROM {self.config.database}.{self.config.table} {where_str} ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}]) AS dist {self.dist_order} LIMIT {topk} """ return q_str
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any, ) -> List[Document]: """使用MyScale通过向量执行相似性搜索 参数: query (str): 查询字符串 k (int, optional): 要检索的前K个邻居。默认为4。 where_str (Optional[str], optional): where条件字符串。 默认为None。 注意: 请不要让最终用户填写这个内容,并始终注意SQL注入问题。 处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。 其默认名称为`metadata`。 返回: List[Document]: (Document, 相似度)的列表 """ q_str = self._build_qstr(embedding, k, where_str) try: return [ Document( page_content=r[self.config.column_map["text"]], metadata={k: r[k] for k in self.must_have_cols}, ) for r in self.client.query(q_str).named_results() ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return []
[docs] def similarity_search_with_relevance_scores( self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any ) -> List[Tuple[Document, float]]: """使用MyScale进行相似性搜索 参数: query (str): 查询字符串 k (int, optional): 要检索的前K个最近邻居。默认为4。 where_str (Optional[str], optional): where条件字符串。 默认为None。 注意: 请不要让最终用户填写此内容,并始终注意SQL注入的问题。 处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。 其默认名称为`metadata`。 返回: List[Document]: 与查询文本最相似的文档列表,以及每个文档的余弦距离。 较低的分数表示更相似。 """ q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str) try: return [ ( Document( page_content=r[self.config.column_map["text"]], metadata={k: r[k] for k in self.must_have_cols}, ), r["dist"], ) for r in self.client.query(q_str).named_results() ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") return []
@property def metadata_column(self) -> str: return ""