import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
DEFAULT_DISTANCE_STRATEGY = "cosine" # or "l2", "inner_product"
DEFAULT_TiDB_VECTOR_TABLE_NAME = "langchain_vector"
[docs]class TiDBVectorStore(VectorStore):
"""TiDB 向量存储。"""
[docs] def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
table_name: str = DEFAULT_TiDB_VECTOR_TABLE_NAME,
distance_strategy: str = DEFAULT_DISTANCE_STRATEGY,
*,
engine_args: Optional[Dict[str, Any]] = None,
drop_existing_table: bool = False,
**kwargs: Any,
) -> None:
"""在Langchain中使用灵活和标准化的表结构初始化TiDB向量存储,用于存储向量数据,无论动态表名设置如何,该结构保持不变。
向量表结构包括:
- 'id':每个条目的UUID。
- 'embedding':在VectorType列中存储向量数据。
- 'document':用于存储原始数据或附加信息的Text列。
- 'meta':用于灵活存储元数据的JSON列。
- 'create_time'和'update_time':用于跟踪数据更改的时间戳列。
该表结构适用于一般用例和复杂场景,其中表充当高级数据集成和分析的语义层,利用SQL进行连接查询。
参数:
connection_string (str):TiDB数据库的连接字符串,
格式:"mysql+pymysql://root@34.212.137.91:4000/test"。
embedding_function:用于生成嵌入的嵌入函数。
table_name (str, 可选):用于存储向量数据的表的名称。如果不提供表名,
将自动创建一个名为`langchain_vector`的默认表。
distance_strategy:用于相似性搜索的策略,默认为"cosine",有效值为:"l2"、"cosine"、"inner_product"。
engine_args (Optional[Dict]):数据库引擎的附加参数,默认为None。
drop_existing_table:在初始化之前删除现有的TiDB表,默认为False。
**kwargs (Any):其他关键字参数。
示例:
.. code-block:: python
from langchain_community.vectorstores import TiDBVectorStore
from langchain_openai import OpenAIEmbeddings
embeddingFunc = OpenAIEmbeddings()
CONNECTION_STRING = "mysql+pymysql://root@34.212.137.91:4000/test"
vs = TiDBVector.from_texts(
embedding=embeddingFunc,
texts = [..., ...],
connection_string=CONNECTION_STRING,
distance_strategy="l2",
table_name="tidb_vector_langchain",
)
query = "What did the president say about Ketanji Brown Jackson"
docs = db.similarity_search_with_score(query)
"""
super().__init__(**kwargs)
self._connection_string = connection_string
self._embedding_function = embedding_function
self._distance_strategy = distance_strategy
self._vector_dimension = self._get_dimension()
try:
from tidb_vector.integrations import TiDBVectorClient
except ImportError:
raise ImportError(
"Could not import tidbvec python package. "
"Please install it with `pip install tidb-vector`."
)
self._tidb = TiDBVectorClient(
connection_string=connection_string,
table_name=table_name,
distance_strategy=distance_strategy,
vector_dimension=self._vector_dimension,
engine_args=engine_args,
drop_existing_table=drop_existing_table,
**kwargs,
)
@property
def embeddings(self) -> Embeddings:
"""返回用于生成嵌入向量的函数。"""
return self._embedding_function
@property
def tidb_vector_client(self) -> Any:
"""返回 TiDB 向量客户端。"""
return self._tidb
@property
def distance_strategy(self) -> Any:
"""
返回当前的距离策略。
"""
return self._distance_strategy
def _get_dimension(self) -> int:
"""
使用嵌入函数获取向量的维度。
"""
return len(self._embedding_function.embed_query("test embedding length"))
[docs] @classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> "TiDBVectorStore":
"""从文本列表创建一个VectorStore。
参数:
texts(List[str]):要添加到TiDB Vector的文本列表。
embedding(Embeddings):用于生成嵌入的函数。
metadatas:与每个文本对应的元数据字典列表,默认为None。
**kwargs(Any):额外的关键字参数。
connection_string(str):TiDB数据库的连接字符串,格式为:"mysql+pymysql://root@34.212.137.91:4000/test"。
table_name(str,可选):用于存储向量数据的表的名称,默认为"langchain_vector"。
distance_strategy:用于相似性搜索的距离策略,默认为"cosine",允许值:"l2","cosine","inner_product"。
ids(Optional[List[str]):与每个文本对应的ID列表,默认为None。
engine_args:底层数据库引擎的额外参数,默认为None。
drop_existing_table:在初始化之前删除现有的TiDB表,默认为False。
返回:
VectorStore:创建的TiDB Vector Store。
"""
# Extract arguments from kwargs with default values
connection_string = kwargs.pop("connection_string", None)
if connection_string is None:
raise ValueError("please provide your tidb connection_url")
table_name = kwargs.pop("table_name", "langchain_vector")
distance_strategy = kwargs.pop("distance_strategy", "cosine")
ids = kwargs.pop("ids", None)
engine_args = kwargs.pop("engine_args", None)
drop_existing_table = kwargs.pop("drop_existing_table", False)
embeddings = embedding.embed_documents(list(texts))
vs = cls(
connection_string=connection_string,
table_name=table_name,
embedding_function=embedding,
distance_strategy=distance_strategy,
engine_args=engine_args,
drop_existing_table=drop_existing_table,
**kwargs,
)
vs._tidb.insert(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
return vs
[docs] @classmethod
def from_existing_vector_table(
cls,
embedding: Embeddings,
connection_string: str,
table_name: str,
distance_strategy: str = DEFAULT_DISTANCE_STRATEGY,
*,
engine_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> VectorStore:
"""从TiDB中的现有TiDB向量存储中创建一个VectorStore实例。
参数:
embedding (Embeddings):用于生成嵌入的函数。
connection_string (str):TiDB数据库的连接字符串,
格式:"mysql+pymysql://root@34.212.137.91:4000/test"。
table_name (str, 可选):用于存储向量数据的表的名称,
默认为"langchain_vector"。
distance_strategy:用于相似性搜索的距离策略,
默认为"cosine",允许值:"l2","cosine",'inner_product'。
engine_args:底层数据库引擎的附加参数,
默认为None。
**kwargs (Any):其他关键字参数。
返回:
VectorStore:VectorStore实例。
抛出:
NoSuchTableError:如果在TiDB中指定的表不存在。
"""
try:
from tidb_vector.integrations import check_table_existence
except ImportError:
raise ImportError(
"Could not import tidbvec python package. "
"Please install it with `pip install tidb-vector`."
)
if check_table_existence(connection_string, table_name):
return cls(
connection_string=connection_string,
table_name=table_name,
embedding_function=embedding,
distance_strategy=distance_strategy,
engine_args=engine_args,
**kwargs,
)
else:
raise ValueError(f"Table {table_name} does not exist in the TiDB database.")
[docs] def drop_vectorstore(self) -> None:
"""
从TiDB数据库中删除Vector Store。
"""
self._tidb.drop_table()
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""将文本添加到 TiDB 向量存储中。
参数:
texts (Iterable[str]): 要添加的文本。
metadatas (Optional[List[dict]]): 与每个文本相关联的元数据,默认为 None。
ids (Optional[List[str]]): 要分配给每个文本的 ID,默认为 None,如果未提供将生成。
返回:
List[str]: 分配给添加的文本的 ID。
"""
embeddings = self._embedding_function.embed_documents(list(texts))
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
if not metadatas:
metadatas = [{} for _ in texts]
return self._tidb.insert(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
[docs] def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""从TiDB矢量存储中删除矢量数据。
参数:
ids(可选[List[str]]):要删除的矢量ID列表。
**kwargs:额外的关键字参数。
"""
self._tidb.delete(ids=ids, **kwargs)
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""使用给定的查询执行相似性搜索。
参数:
query(str):查询字符串。
k(int,可选):要检索的结果数量。默认为4。
filter(dict,可选):要应用于搜索结果的过滤器。
默认为None。
**kwargs:其他关键字参数。
返回:
List[Document]:表示搜索结果的Document对象列表。
"""
result = self.similarity_search_with_score(query, k, filter, **kwargs)
return [doc for doc, _ in result]
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 5,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""执行基于给定查询的相似性搜索,并基于得分进行排序。
参数:
query(str):查询字符串。
k(int,可选):要返回的结果数量。默认为5。
filter(dict,可选):要应用于搜索结果的过滤器。默认为None。
**kwargs:其他关键字参数。
返回:
包含相关文档及其相似性分数的元组列表。
"""
query_vector = self._embedding_function.embed_query(query)
relevant_docs = self._tidb.query(
query_vector=query_vector, k=k, filter=filter, **kwargs
)
return [
(
Document(
page_content=doc.document,
metadata=doc.metadata,
),
doc.distance,
)
for doc in relevant_docs
]
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
根据距离策略选择相关性评分函数。
"""
if self._distance_strategy == "cosine":
return self._cosine_relevance_score_fn
elif self._distance_strategy == "l2":
return self._euclidean_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
f" for distance_strategy of {self._distance_strategy}."
"Consider providing relevance_score_fn to PGVector constructor."
)