from __future__ import annotations
import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseSettings
from langchain_core.vectorstores import VectorStore
logger = logging.getLogger()
[docs]def has_mul_sub_str(s: str, *args: Any) -> bool:
"""检查字符串是否包含多个子字符串。
参数:
s:要检查的字符串。
*args:要检查的子字符串。
返回:
如果所有子字符串都在字符串中,则返回True,否则返回False。
"""
for a in args:
if a not in s:
return False
return True
[docs]class MyScaleSettings(BaseSettings):
"""MyScale客户端配置。
属性:
myscale_host (str) : 用于连接到MyScale后端的URL。
默认为'localhost'。
myscale_port (int) : 用于通过HTTP连接的URL端口。默认为8443。
username (str) : 登录的用户名。默认为None。
password (str) : 登录的密码。默认为None。
index_type (str): 索引类型字符串。
index_param (dict): 索引构建参数。
database (str) : 要查找表的数据库名称。默认为'default'。
table (str) : 要操作的表名称。
默认为'vector_table'。
metric (str) : 计算距离的度量标准,
支持的有('L2', 'Cosine', 'IP')。默认为'Cosine'。
column_map (Dict) : 列类型映射,将列名投影到langchain语义上。
必须具有键:`text`,`id`,`vector`,
必须与列数相同。例如:
.. code-block:: python
{
'id': 'text_id',
'vector': 'text_embedding',
'text': 'text_plain',
'metadata': 'metadata_dictionary_in_json',
}
默认为身份映射。"""
host: str = "localhost"
port: int = 8443
username: Optional[str] = None
password: Optional[str] = None
index_type: str = "MSTG"
index_param: Optional[Dict[str, str]] = None
column_map: Dict[str, str] = {
"id": "id",
"text": "text",
"vector": "vector",
"metadata": "metadata",
}
database: str = "default"
table: str = "langchain"
metric: str = "Cosine"
def __getitem__(self, item: str) -> Any:
return getattr(self, item)
class Config:
env_file = ".env"
env_prefix = "myscale_"
env_file_encoding = "utf-8"
[docs]class MyScale(VectorStore):
"""`MyScale`向量存储。
您需要一个`clickhouse-connect`的Python包,以及一个有效的账户
来连接到MyScale。
MyScale不仅可以使用简单的向量索引进行搜索。
它还支持具有多个条件、约束甚至子查询的复杂查询。
欲了解更多信息,请访问
[MyScale官方网站](https://docs.myscale.com/en/overview/)"""
[docs] def __init__(
self,
embedding: Embeddings,
config: Optional[MyScaleSettings] = None,
**kwargs: Any,
) -> None:
"""MyScale包装器到LangChain
嵌入(Embeddings):
配置(MyScaleSettings):MyScale客户端的配置
其他关键字参数将传递到
[clickhouse-connect](https://docs.myscale.com/)
"""
try:
from clickhouse_connect import get_client
except ImportError:
raise ImportError(
"Could not import clickhouse connect python package. "
"Please install it with `pip install clickhouse-connect`."
)
try:
from tqdm import tqdm
self.pgbar = tqdm
except ImportError:
# Just in case if tqdm is not installed
self.pgbar = lambda x: x
super().__init__()
if config is not None:
self.config = config
else:
self.config = MyScaleSettings()
assert self.config
assert self.config.host and self.config.port
assert (
self.config.column_map
and self.config.database
and self.config.table
and self.config.metric
)
for k in ["id", "vector", "text", "metadata"]:
assert k in self.config.column_map
assert self.config.metric.upper() in ["IP", "COSINE", "L2"]
if self.config.metric in ["ip", "cosine", "l2"]:
logger.warning(
"Lower case metric types will be deprecated "
"the future. Please use one of ('IP', 'Cosine', 'L2')"
)
# initialize the schema
dim = len(embedding.embed_query("try this out"))
index_params = (
", " + ",".join([f"'{k}={v}'" for k, v in self.config.index_param.items()])
if self.config.index_param
else ""
)
schema_ = f"""
CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
{self.config.column_map['id']} String,
{self.config.column_map['text']} String,
{self.config.column_map['vector']} Array(Float32),
{self.config.column_map['metadata']} JSON,
CONSTRAINT cons_vec_len CHECK length(\
{self.config.column_map['vector']}) = {dim},
VECTOR INDEX vidx {self.config.column_map['vector']} \
TYPE {self.config.index_type}(\
'metric_type={self.config.metric}'{index_params})
) ENGINE = MergeTree ORDER BY {self.config.column_map['id']}
"""
self.dim = dim
self.BS = "\\"
self.must_escape = ("\\", "'")
self._embeddings = embedding
self.dist_order = (
"ASC" if self.config.metric.upper() in ["COSINE", "L2"] else "DESC"
)
# Create a connection to myscale
self.client = get_client(
host=self.config.host,
port=self.config.port,
username=self.config.username,
password=self.config.password,
**kwargs,
)
self.client.command("SET allow_experimental_object_type=1")
self.client.command(schema_)
@property
def embeddings(self) -> Embeddings:
return self._embeddings
[docs] def escape_str(self, value: str) -> str:
return "".join(f"{self.BS}{c}" if c in self.must_escape else c for c in value)
def _build_istr(self, transac: Iterable, column_names: Iterable[str]) -> str:
ks = ",".join(column_names)
_data = []
for n in transac:
n = ",".join([f"'{self.escape_str(str(_n))}'" for _n in n])
_data.append(f"({n})")
i_str = f"""
INSERT INTO TABLE
{self.config.database}.{self.config.table}({ks})
VALUES
{','.join(_data)}
"""
return i_str
def _insert(self, transac: Iterable, column_names: Iterable[str]) -> None:
_i_str = self._build_istr(transac, column_names)
self.client.command(_i_str)
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
batch_size: int = 32,
ids: Optional[Iterable[str]] = None,
**kwargs: Any,
) -> List[str]:
"""运行更多文本通过嵌入并添加到向量存储。
参数:
texts:要添加到向量存储的字符串的可迭代对象。
ids:可选的与文本关联的id列表。
batch_size:插入的批量大小
metadata:要插入的可选列数据
返回:
将文本添加到向量存储后的id列表。
"""
# Embed and create the documents
ids = ids or [sha1(t.encode("utf-8")).hexdigest() for t in texts]
colmap_ = self.config.column_map
transac = []
column_names = {
colmap_["id"]: ids,
colmap_["text"]: texts,
colmap_["vector"]: map(self._embeddings.embed_query, texts),
}
metadatas = metadatas or [{} for _ in texts]
column_names[colmap_["metadata"]] = map(json.dumps, metadatas)
assert len(set(colmap_) - set(column_names)) >= 0
keys, values = zip(*column_names.items())
try:
t = None
for v in self.pgbar(
zip(*values), desc="Inserting data...", total=len(metadatas)
):
assert len(v[keys.index(self.config.column_map["vector"])]) == self.dim
transac.append(v)
if len(transac) == batch_size:
if t:
t.join()
t = Thread(target=self._insert, args=[transac, keys])
t.start()
transac = []
if len(transac) > 0:
if t:
t.join()
self._insert(transac, keys)
return [i for i in ids]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
[docs] @classmethod
def from_texts(
cls,
texts: Iterable[str],
embedding: Embeddings,
metadatas: Optional[List[Dict[Any, Any]]] = None,
config: Optional[MyScaleSettings] = None,
text_ids: Optional[Iterable[str]] = None,
batch_size: int = 32,
**kwargs: Any,
) -> MyScale:
"""创建一个使用现有文本的Myscale包装器
参数:
texts (Iterable[str]): 要添加的字符串列表或元组
embedding (Embeddings): 用于提取文本嵌入的函数
config (MyScaleSettings, Optional): Myscale配置
text_ids (Optional[Iterable], optional): 文本的ID。默认为None。
batch_size (int, optional): 传输数据到Myscale时的批处理大小。默认为32。
metadata (List[dict], optional): 文本的元数据。默认为None。
其他关键字参数将传递到
[clickhouse-connect](https://clickhouse.com/docs/en/integrations/python#clickhouse-connect-driver-api)
返回:
MyScale索引
"""
ctx = cls(embedding, config, **kwargs)
ctx.add_texts(texts, ids=text_ids, batch_size=batch_size, metadatas=metadatas)
return ctx
def __repr__(self) -> str:
"""文本表示为myscale,打印后端、用户名和模式。
# 通过`str(Myscale())`很容易使用
返回:
# repr: 显示连接信息和数据模式的字符串
"""
_repr = f"\033[92m\033[1m{self.config.database}.{self.config.table} @ "
_repr += f"{self.config.host}:{self.config.port}\033[0m\n\n"
_repr += f"\033[1musername: {self.config.username}\033[0m\n\nTable Schema:\n"
_repr += "-" * 51 + "\n"
for r in self.client.query(
f"DESC {self.config.database}.{self.config.table}"
).named_results():
_repr += (
f"|\033[94m{r['name']:24s}\033[0m|\033[96m{r['type']:24s}\033[0m|\n"
)
_repr += "-" * 51 + "\n"
return _repr
def _build_qstr(
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
) -> str:
q_emb_str = ",".join(map(str, q_emb))
if where_str:
where_str = f"PREWHERE {where_str}"
else:
where_str = ""
q_str = f"""
SELECT {self.config.column_map['text']},
{self.config.column_map['metadata']}, dist
FROM {self.config.database}.{self.config.table}
{where_str}
ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
AS dist {self.dist_order}
LIMIT {topk}
"""
return q_str
[docs] def similarity_search(
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
) -> List[Document]:
"""使用MyScale执行相似性搜索
参数:
query (str): 查询字符串
k (int, optional): 要检索的前K个邻居。默认为4。
where_str (Optional[str], optional): where条件字符串。
默认为None。
注意:请不要让最终用户填写这个内容,并始终注意SQL注入问题。在处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。它的默认名称是`metadata`。
返回:
List[Document]: 文档列表
"""
return self.similarity_search_by_vector(
self._embeddings.embed_query(query), k, where_str, **kwargs
)
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""使用MyScale通过向量执行相似性搜索
参数:
query (str): 查询字符串
k (int, optional): 要检索的前K个邻居。默认为4。
where_str (Optional[str], optional): where条件字符串。
默认为None。
注意: 请不要让最终用户填写这个内容,并始终注意SQL注入问题。
处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。
其默认名称为`metadata`。
返回:
List[Document]: (Document, 相似度)的列表
"""
q_str = self._build_qstr(embedding, k, where_str)
try:
return [
Document(
page_content=r[self.config.column_map["text"]],
metadata=r[self.config.column_map["metadata"]],
)
for r in self.client.query(q_str).named_results()
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
[docs] def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""使用MyScale进行相似性搜索
参数:
query (str): 查询字符串
k (int, optional): 要检索的前K个最近邻居。默认为4。
where_str (Optional[str], optional): where条件字符串。
默认为None。
注意: 请不要让最终用户填写此内容,并始终注意SQL注入的问题。
处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。
其默认名称为`metadata`。
返回:
List[Document]: 与查询文本最相似的文档列表,以及每个文档的余弦距离。
较低的分数表示更相似。
"""
q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
try:
return [
(
Document(
page_content=r[self.config.column_map["text"]],
metadata=r[self.config.column_map["metadata"]],
),
r["dist"],
)
for r in self.client.query(q_str).named_results()
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
[docs] def drop(self) -> None:
"""
辅助函数:丢弃数据
"""
self.client.command(
f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}"
)
[docs] def delete(
self,
ids: Optional[List[str]] = None,
where_str: Optional[str] = None,
**kwargs: Any,
) -> Optional[bool]:
"""根据向量ID或其他条件删除。
参数:
ids:要删除的ID列表。
**kwargs:子类可能使用的其他关键字参数。
返回:
Optional[bool]:如果删除成功则为True,否则为False,如果未实现则为None。
"""
assert not (
ids is None and where_str is None
), "You need to specify where to be deleted! Either with `ids` or `where_str`"
conds = []
if ids and len(ids) > 0:
id_list = ", ".join([f"'{id}'" for id in ids])
conds.append(f"{self.config.column_map['id']} IN ({id_list})")
if where_str:
conds.append(where_str)
assert len(conds) > 0
where_str_final = " AND ".join(conds)
qstr = (
f"DELETE FROM {self.config.database}.{self.config.table} "
f"WHERE {where_str_final}"
)
try:
self.client.command(qstr)
return True
except Exception as e:
logger.error(str(e))
return False
@property
def metadata_column(self) -> str:
return self.config.column_map["metadata"]
[docs]class MyScaleWithoutJSON(MyScale):
"""我的规模向量存储没有元数据列
如果您正在处理一个SQL本地表,这将非常方便"""
[docs] def __init__(
self,
embedding: Embeddings,
config: Optional[MyScaleSettings] = None,
must_have_cols: List[str] = [],
**kwargs: Any,
) -> None:
"""构建一个没有元数据列的myscale向量存储
embedding (嵌入): 嵌入模型
config (MyScaleSettings): MyScale客户端的配置
must_have_cols (List[str]): 查询中要包含的列名
其他关键字参数将传递到
[clickhouse-connect](https://docs.myscale.com/)
"""
super().__init__(embedding, config, **kwargs)
self.must_have_cols: List[str] = must_have_cols
def _build_qstr(
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
) -> str:
q_emb_str = ",".join(map(str, q_emb))
if where_str:
where_str = f"PREWHERE {where_str}"
else:
where_str = ""
q_str = f"""
SELECT {self.config.column_map['text']}, dist,
{','.join(self.must_have_cols)}
FROM {self.config.database}.{self.config.table}
{where_str}
ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
AS dist {self.dist_order}
LIMIT {topk}
"""
return q_str
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
where_str: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""使用MyScale通过向量执行相似性搜索
参数:
query (str): 查询字符串
k (int, optional): 要检索的前K个邻居。默认为4。
where_str (Optional[str], optional): where条件字符串。
默认为None。
注意: 请不要让最终用户填写这个内容,并始终注意SQL注入问题。
处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。
其默认名称为`metadata`。
返回:
List[Document]: (Document, 相似度)的列表
"""
q_str = self._build_qstr(embedding, k, where_str)
try:
return [
Document(
page_content=r[self.config.column_map["text"]],
metadata={k: r[k] for k in self.must_have_cols},
)
for r in self.client.query(q_str).named_results()
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
[docs] def similarity_search_with_relevance_scores(
self, query: str, k: int = 4, where_str: Optional[str] = None, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""使用MyScale进行相似性搜索
参数:
query (str): 查询字符串
k (int, optional): 要检索的前K个最近邻居。默认为4。
where_str (Optional[str], optional): where条件字符串。
默认为None。
注意: 请不要让最终用户填写此内容,并始终注意SQL注入的问题。
处理元数据时,请记住使用`{self.metadata_column}.attribute`而不是仅使用`attribute`。
其默认名称为`metadata`。
返回:
List[Document]: 与查询文本最相似的文档列表,以及每个文档的余弦距离。
较低的分数表示更相似。
"""
q_str = self._build_qstr(self._embeddings.embed_query(query), k, where_str)
try:
return [
(
Document(
page_content=r[self.config.column_map["text"]],
metadata={k: r[k] for k in self.must_have_cols},
),
r["dist"],
)
for r in self.client.query(q_str).named_results()
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
return []
@property
def metadata_column(self) -> str:
return ""