from __future__ import annotations
import logging
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type
from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
# Default embedding vector size, used when `embedding_dimension` is not supplied.
_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
# Default collection name; the collection name is also used as the table name.
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"
# Declarative base whose shared metadata registers the per-collection tables.
Base = declarative_base() # type: Any
[docs]class AnalyticDB(VectorStore):
    """`AnalyticDB` (distributed PostgreSQL) vector store.

    AnalyticDB is a distributed, cloud-native database with full PostgreSQL
    syntax support.

    - `connection_string` is a PostgreSQL connection string.
    - `embedding_function` is any embedding function implementing the
      `langchain.embeddings.base.Embeddings` interface.
    - `embedding_dimension` is the vector size passed to the ANN index
      (default: 1536).
    - `collection_name` is the name of the collection to use
      (default: langchain_document).
        - NOTE: the collection name is used directly as the table name.
          The table is created when the store is initialized (if it does
          not exist), so make sure the user has permission to create tables.
    - `pre_delete_collection`: if True, the collection is dropped first
      (if it exists). (default: False)
        - Useful for testing.
    """
[docs] def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
engine_args: Optional[dict] = None,
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
self.embedding_dimension = embedding_dimension
self.collection_name = collection_name
self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__)
self.__post_init__(engine_args)
def __post_init__(
    self,
    engine_args: Optional[dict] = None,
) -> None:
    """Create the SQLAlchemy engine and set up the collection.

    Args:
        engine_args: Optional keyword arguments forwarded to
            ``create_engine``. ``pool_recycle`` defaults to 3600 seconds
            when the caller does not set it.
    """
    # Copy first so the default added below never leaks into the caller's dict.
    _engine_args = dict(engine_args or {})
    _engine_args.setdefault("pool_recycle", 3600)
    self.engine = create_engine(self.connection_string, **_engine_args)
    self.create_collection()
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._euclidean_relevance_score_fn
def create_table_if_not_exists(self) -> None:
    """Create the collection table and its ANN index if they are missing.

    The table is named after ``collection_name`` and the index is named
    ``<collection_name>_embedding_idx``. Both steps run in one transaction.
    """
    # Register the table (named after the collection) on the shared metadata.
    Table(
        self.collection_name,
        Base.metadata,
        Column("id", TEXT, primary_key=True, default=uuid.uuid4),
        Column("embedding", ARRAY(REAL)),
        Column("document", String, nullable=True),
        Column("metadata", JSON, nullable=True),
        extend_existing=True,
    )
    index_name = f"{self.collection_name}_embedding_idx"
    # Bind the index name as a parameter instead of splicing it into a
    # SQL string literal.
    index_query = text(
        """
        SELECT 1
        FROM pg_indexes
        WHERE indexname = :index_name;
        """
    )
    with self.engine.connect() as conn:
        with conn.begin():
            # Emits CREATE TABLE for every table registered on Base.metadata
            # that does not exist yet.
            Base.metadata.create_all(conn)
            index_exists = conn.execute(
                index_query, {"index_name": index_name}
            ).scalar()
            if not index_exists:
                # Identifiers and index options cannot be bound as
                # parameters, so they are interpolated here.
                index_statement = text(
                    f"""
                    CREATE INDEX {index_name}
                    ON {self.collection_name} USING ann(embedding)
                    WITH (
                    "dim" = {self.embedding_dimension},
                    "hnsw_m" = 100
                    );
                    """
                )
                conn.execute(index_statement)
[docs] def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
self.create_table_if_not_exists()
def delete_collection(self) -> None:
    """Drop the collection's table if it exists."""
    self.logger.debug("Trying to delete collection")
    statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
    # Open a connection and run the drop inside a single transaction.
    with self.engine.connect() as connection, connection.begin():
        connection.execute(statement)
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    batch_size: int = 500,
    **kwargs: Any,
) -> List[str]:
    """Embed texts and insert them into the collection table.

    Args:
        texts: Iterable of strings to add to the vector store. One-shot
            iterators (e.g. generators) are supported.
        metadatas: Optional list of metadata dicts, one per text. Empty
            dicts are used when omitted.
        ids: Optional list of ids, one per text. Random UUID strings are
            generated when omitted.
        batch_size: Number of rows sent per INSERT statement.
        kwargs: Vector-store specific parameters (unused here).

    Returns:
        The list of ids of the texts added to the vector store.
    """
    # Materialize once: ``texts`` is used several times below, and a
    # generator would be exhausted after the first pass.
    text_list = list(texts)

    if ids is None:
        ids = [str(uuid.uuid4()) for _ in text_list]
    embeddings = self.embedding_function.embed_documents(text_list)
    if not metadatas:
        metadatas = [{} for _ in text_list]

    # Define the table schema
    chunks_table = Table(
        self.collection_name,
        Base.metadata,
        Column("id", TEXT, primary_key=True),
        Column("embedding", ARRAY(REAL)),
        Column("document", String, nullable=True),
        Column("metadata", JSON, nullable=True),
        extend_existing=True,
    )

    chunks_table_data = []
    with self.engine.connect() as conn:
        with conn.begin():
            for document, metadata, chunk_id, embedding in zip(
                text_list, metadatas, ids, embeddings
            ):
                chunks_table_data.append(
                    {
                        "id": chunk_id,
                        "embedding": embedding,
                        "document": document,
                        "metadata": metadata,
                    }
                )
                # Flush a full batch and start collecting the next one.
                if len(chunks_table_data) == batch_size:
                    conn.execute(insert(chunks_table).values(chunks_table_data))
                    chunks_table_data.clear()
            # Insert any remaining rows that did not fill a whole batch.
            if chunks_table_data:
                conn.execute(insert(chunks_table).values(chunks_table_data))
    return ids
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""在AnalyticDB中运行带有距离的相似性搜索。
参数:
query (str): 要搜索的查询文本。
k (int): 要返回的结果数量。默认为4。
filter (Optional[Dict[str, str]]): 按元数据进行过滤。默认为None。
返回:
与查询最相似的文档列表。
"""
embedding = self.embedding_function.embed_query(text=query)
return self.similarity_search_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
"""返回与查询最相似的文档。
参数:
query:要查找相似文档的文本。
k:要返回的文档数量。默认为4。
filter(可选[Dict[str,str]]):按元数据过滤。默认为无。
返回:
返回与查询最相似的文档列表,以及每个文档的得分。
"""
embedding = self.embedding_function.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter
)
return docs
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Return the rows closest to an embedding, with their L2 distance.

    Args:
        embedding: Query vector to compare against stored embeddings.
        k: Number of documents to return. Defaults to 4.
        filter: Optional metadata filter. Every key/value pair must match;
            values are compared as text against ``metadata->>key``. An
            empty dict is treated the same as no filter.

    Returns:
        List of (Document, distance) pairs ordered by increasing distance.

    Raises:
        ImportError: If ``Row`` cannot be imported from ``sqlalchemy.engine``.
    """
    try:
        from sqlalchemy.engine import Row
    except ImportError as exc:
        raise ImportError(
            "Could not import Row from sqlalchemy.engine. "
            "Please 'pip install sqlalchemy>=1.4'."
        ) from exc

    # Set up the query parameters
    params: Dict[str, Any] = {"embedding": embedding, "k": k}

    # Build the metadata filter with bound parameters so keys and values
    # are never spliced into the SQL text. A falsy filter (None or {})
    # adds no WHERE clause.
    filter_condition = ""
    if filter:
        conditions = []
        for position, (key, value) in enumerate(filter.items()):
            key_param = f"filter_key_{position}"
            value_param = f"filter_value_{position}"
            conditions.append(f"metadata->>:{key_param} = :{value_param}")
            params[key_param] = key
            # ``->>`` yields text, so compare against a text value.
            params[value_param] = value if isinstance(value, str) else str(value)
        filter_condition = f"WHERE {' AND '.join(conditions)}"

    # Define the base query
    sql_query = f"""
        SELECT *, l2_distance(embedding, :embedding) as distance
        FROM {self.collection_name}
        {filter_condition}
        ORDER BY embedding <-> :embedding
        LIMIT :k
    """

    # Execute the query and fetch the results
    with self.engine.connect() as conn:
        results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()

    documents_with_scores = [
        (
            Document(
                page_content=result.document,
                metadata=result.metadata,
            ),
            result.distance if self.embedding_function is not None else None,
        )
        for result in results
    ]
    return documents_with_scores
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
"""返回与嵌入向量最相似的文档。
参数:
embedding: 要查找与之相似文档的嵌入。
k: 要返回的文档数量。默认为4。
filter (Optional[Dict[str, str]]): 按元数据过滤。默认为None。
返回:
与查询向量最相似的文档列表。
"""
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter
)
return [doc for doc, _ in docs_and_scores]
[docs] def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""根据向量ID删除。
参数:
ids:要删除的ID列表。
"""
if ids is None:
raise ValueError("No ids provided to delete.")
# Define the table schema
chunks_table = Table(
self.collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", ARRAY(REAL)),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
try:
with self.engine.connect() as conn:
with conn.begin():
delete_condition = chunks_table.c.id.in_(ids)
conn.execute(chunks_table.delete().where(delete_condition))
return True
except Exception as e:
print("Delete operation failed:", str(e)) # noqa: T201
return False
[docs] @classmethod
def from_texts(
cls: Type[AnalyticDB],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> AnalyticDB:
"""返回从文本和嵌入初始化的VectorStore。
需要Postgres连接字符串。
可以将其作为参数传递,
或设置PG_CONNECTION_STRING环境变量。
"""
connection_string = cls.get_connection_string(kwargs)
store = cls(
connection_string=connection_string,
collection_name=collection_name,
embedding_function=embedding,
embedding_dimension=embedding_dimension,
pre_delete_collection=pre_delete_collection,
engine_args=engine_args,
)
store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
return store
[docs] @classmethod
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
connection_string: str = get_from_dict_or_env(
data=kwargs,
key="connection_string",
env_key="PG_CONNECTION_STRING",
)
if not connection_string:
raise ValueError(
"Postgres connection string is required"
"Either pass it as a parameter"
"or set the PG_CONNECTION_STRING environment variable."
)
return connection_string
[docs] @classmethod
def from_documents(
cls: Type[AnalyticDB],
documents: List[Document],
embedding: Embeddings,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> AnalyticDB:
"""返回从文档和嵌入初始化的VectorStore。
需要Postgres连接字符串
可以作为参数传递
或设置PG_CONNECTION_STRING环境变量。
"""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
connection_string = cls.get_connection_string(kwargs)
kwargs["connection_string"] = connection_string
return cls.from_texts(
texts=texts,
pre_delete_collection=pre_delete_collection,
embedding=embedding,
embedding_dimension=embedding_dimension,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
engine_args=engine_args,
**kwargs,
)
[docs] @classmethod
def connection_string_from_db_params(
cls,
driver: str,
host: str,
port: int,
database: str,
user: str,
password: str,
) -> str:
"""从数据库参数返回连接字符串。"""
return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"