from __future__ import annotations
import contextlib
import enum
import logging
import uuid
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
)
import numpy as np
import sqlalchemy
from sqlalchemy import delete, func
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name
from langchain_community.vectorstores.utils import maximal_marginal_relevance
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
ADA_TOKEN_COUNT = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
def _results_to_docs(docs_and_scores: Any) -> List[Document]:
"""从文档和分数中返回文档。"""
return [doc for doc, _ in docs_and_scores]
[docs]class BaseEmbeddingStore:
"""用于灯笼嵌入存储的基类。"""
[docs]def get_embedding_store(
distance_strategy: DistanceStrategy, collection_name: str
) -> Any:
"""获取嵌入式存储类。"""
embedding_type = None
if distance_strategy == DistanceStrategy.HAMMING:
embedding_type = sqlalchemy.INTEGER # type: ignore
else:
embedding_type = sqlalchemy.REAL # type: ignore
DynamicBase = declarative_base(class_registry=dict()) # type: Any
class EmbeddingStore(DynamicBase, BaseEmbeddingStore):
__tablename__ = collection_name
uuid = sqlalchemy.Column(
UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
__table_args__ = {"extend_existing": True}
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore
return EmbeddingStore
[docs]class QueryResult:
"""查询结果。"""
EmbeddingStore: BaseEmbeddingStore
distance: float
[docs]class DistanceStrategy(str, enum.Enum):
"""距离策略的枚举器。"""
EUCLIDEAN = "l2sq"
COSINE = "cosine"
HAMMING = "hamming"
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
[docs]class Lantern(VectorStore):
"""`Postgres`与`lantern`扩展作为向量存储。
lantern默认使用顺序扫描。但是你可以使用create_hnsw_index方法创建HNSW索引。
- `connection_string`是一个postgres连接字符串。
- `embedding_function`是实现`langchain.embeddings.base.Embeddings`接口的任何嵌入函数。
- `collection_name`是要使用的集合的名称。(默认值:langchain)
- 注意:这是嵌入数据将被存储的表的名称
表将在初始化存储时创建(如果不存在)
因此,请确保用户有创建表的权限。
- `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN)
- `EUCLIDEAN`是欧氏距离。
- `COSINE`是余弦距离。
- `HAMMING`是汉明距离。
- `pre_delete_collection`如果为True,将删除集合(如果存在)。
(默认值:False)
- 用于测试。"""
[docs] def __init__(
self,
connection_string: str,
embedding_function: Embeddings,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
self.collection_name = collection_name
self.collection_metadata = collection_metadata
self._distance_strategy = distance_strategy
self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__)
self.override_relevance_score_fn = relevance_score_fn
self.EmbeddingStore = get_embedding_store(
self.distance_strategy, collection_name
)
self.__post_init__()
def __post_init__(
self,
) -> None:
self._conn = self.connect()
self.create_hnsw_extension()
self.create_collection()
@property
def distance_strategy(self) -> DistanceStrategy:
if isinstance(self._distance_strategy, DistanceStrategy):
return self._distance_strategy
if self._distance_strategy == DistanceStrategy.EUCLIDEAN.value:
return DistanceStrategy.EUCLIDEAN
elif self._distance_strategy == DistanceStrategy.COSINE.value:
return DistanceStrategy.COSINE
elif self._distance_strategy == DistanceStrategy.HAMMING.value:
return DistanceStrategy.HAMMING
else:
raise ValueError(
f"Got unexpected value for distance: {self._distance_strategy}. "
f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
)
@property
def embeddings(self) -> Embeddings:
return self.embedding_function
[docs] @classmethod
def connection_string_from_db_params(
cls,
driver: str,
host: str,
port: int,
database: str,
user: str,
password: str,
) -> str:
"""从数据库参数返回连接字符串。"""
return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
[docs] def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string)
conn = engine.connect()
return conn
@property
def distance_function(self) -> Any:
if self.distance_strategy == DistanceStrategy.EUCLIDEAN:
return "l2sq_dist"
elif self.distance_strategy == DistanceStrategy.COSINE:
return "cos_dist"
elif self.distance_strategy == DistanceStrategy.HAMMING:
return "hamming_dist"
[docs] def create_hnsw_extension(self) -> None:
try:
with Session(self._conn) as session:
statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS lantern")
session.execute(statement)
session.commit()
except Exception as e:
self.logger.exception(e)
[docs] def create_tables_if_not_exists(self) -> None:
try:
self.create_collection()
except ProgrammingError:
pass
[docs] def drop_table(self) -> None:
try:
self.EmbeddingStore.__table__.drop(self._conn.engine)
except ProgrammingError:
pass
[docs] def drop_tables(self) -> None:
self.drop_table()
def _hamming_relevance_score_fn(self, distance: float) -> float:
return distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""“正确”的相关性函数可能会有所不同,取决于一些因素,包括:
- 向量存储中使用的距离/相似度度量
- 嵌入的规模(OpenAI的是单位规范化的。许多其他嵌入不是!)
- 嵌入的维度
- 等等。
"""
if self.override_relevance_score_fn is not None:
return self.override_relevance_score_fn
# Default strategy is to rely on distance strategy provided
# in vectorstore constructor
if self.distance_strategy == DistanceStrategy.COSINE:
return self._cosine_relevance_score_fn
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN:
return self._euclidean_relevance_score_fn
elif self.distance_strategy == DistanceStrategy.HAMMING:
return self._hamming_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
f" for distance_strategy of {self._distance_strategy}."
"Consider providing relevance_score_fn to Lantern constructor."
)
def _get_op_class(self) -> str:
if self.distance_strategy == DistanceStrategy.COSINE:
return "dist_cos_ops"
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN:
return "dist_l2sq_ops"
elif self.distance_strategy == DistanceStrategy.HAMMING:
return "dist_hamming_ops"
else:
raise ValueError(
"No supported operator class"
f" for distance_strategy of {self._distance_strategy}."
)
def _get_operator(self) -> str:
if self.distance_strategy == DistanceStrategy.COSINE:
return "<=>"
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN:
return "<->"
elif self.distance_strategy == DistanceStrategy.HAMMING:
return "<+>"
else:
raise ValueError(
"No supported operator"
f" for distance_strategy of {self._distance_strategy}."
)
def _typed_arg_for_distance(
self, embedding: List[Union[float, int]]
) -> List[Union[float, int]]:
if self.distance_strategy == DistanceStrategy.HAMMING:
return list(map(lambda x: int(x), embedding))
return embedding
@property
def _index_name(self) -> str:
return f"langchain_{self.collection_name}_idx"
[docs] def create_hnsw_index(
self,
dims: int = ADA_TOKEN_COUNT,
m: int = 16,
ef_construction: int = 64,
ef_search: int = 64,
**_kwargs: Any,
) -> None:
"""在集合上创建HNSW索引。
HNSW索引的可选关键字参数:
engine: "nmslib", "faiss", "lucene"; 默认值: "nmslib"
ef: k-NN搜索期间使用的动态列表的大小。较高的值会导致更准确但更慢的搜索;默认值: 64
ef_construction: k-NN图创建期间使用的动态列表的大小。较高的值会导致更准确的图形但索引速度较慢;默认值: 64
m: 每个新元素创建的双向链接数。对内存消耗有很大影响。介于2和100之间;默认值: 16
dims: 集合中向量的维度。默认值: 1536
"""
create_index_query = sqlalchemy.text(
"CREATE INDEX IF NOT EXISTS {} "
"ON {} USING hnsw (embedding {}) "
"WITH ("
"dim = :dim, "
"m = :m, "
"ef_construction = :ef_construction, "
"ef = :ef"
");".format(
quoted_name(self._index_name, True),
quoted_name(self.collection_name, True),
self._get_op_class(),
)
)
with Session(self._conn) as session:
# Create the HNSW index
session.execute(
create_index_query,
{
"dim": dims,
"m": m,
"ef_construction": ef_construction,
"ef": ef_search,
},
)
session.commit()
self.logger.info("HNSW extension and index created successfully.")
[docs] def drop_index(self) -> None:
with Session(self._conn) as session:
# Drop the HNSW index
session.execute(
sqlalchemy.text(
"DROP INDEX IF EXISTS {}".format(
quoted_name(self._index_name, True)
)
)
)
session.commit()
[docs] def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
self.drop_table()
with self._conn.begin():
try:
self.EmbeddingStore.__table__.create(self._conn.engine)
except ProgrammingError as e:
# Duplicate table
if e.code == "f405":
pass
else:
raise e
[docs] def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
self.drop_table()
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""为会话创建一个上下文管理器,绑定到_conn字符串。"""
yield Session(self._conn)
[docs] def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""根据id或uuid删除向量。
参数:
ids:要删除的id列表。
"""
with Session(self._conn) as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
"using the custom ids field)"
)
stmt = delete(self.EmbeddingStore).where(
self.EmbeddingStore.custom_id.in_(ids)
)
session.execute(stmt)
session.commit()
@classmethod
def _initialize_from_embeddings(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> Lantern:
"""列表`ids`,`embeddings`,`texts`,`metadatas`的元素顺序应该匹配,这样每一行将与正确的值关联。
需要Postgres连接字符串
"可以将其作为`connection_string`参数传递
或设置LANTERN_CONNECTION_STRING环境变量。
- `texts` 要插入到集合中的文本。
- `embeddings` 要插入到集合中的嵌入。
- `embedding` 是要用于嵌入发送的文本的:class:`Embeddings`。
如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。
- `metadatas` 要插入到集合中的行元数据。
- `ids` 要插入到集合中的行ID。
- `collection_name` 是要使用的集合的名称(默认值:langchain)。
- 注意:这是嵌入数据将被存储的表的名称
表将在初始化存储时创建(如果不存在)
因此,请确保用户具有创建表的权限。
- `distance_strategy` 是要使用的距离策略(默认值:EUCLIDEAN)。
- `EUCLIDEAN` 是欧几里德距离。
- `COSINE` 是余弦距禂。
- `HAMMING` 是汉明距离。
- `pre_delete_collection` 如果为True,则如果存在,将删除集合。
(默认值:False)
- 用于测试。
"""
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
if not metadatas:
metadatas = [{} for _ in texts]
connection_string = cls.__get_connection_string(kwargs)
store = cls(
connection_string=connection_string,
collection_name=collection_name,
embedding_function=embedding,
pre_delete_collection=pre_delete_collection,
distance_strategy=distance_strategy,
)
store.add_embeddings(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
store.create_hnsw_index(**kwargs)
return store
[docs] def add_embeddings(
self,
texts: List[str],
embeddings: List[List[float]],
metadatas: List[dict],
ids: List[str],
**kwargs: Any,
) -> None:
with Session(self._conn) as session:
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
embedding_store = self.EmbeddingStore(
embedding=embedding,
document=text,
cmetadata=metadata,
custom_id=id,
)
session.add(embedding_store)
session.commit()
[docs] def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
embeddings = self.embedding_function.embed_documents(list(texts))
if not metadatas:
metadatas = [{} for _ in texts]
with Session(self._conn) as session:
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
embedding_store = self.EmbeddingStore(
embedding=embedding,
document=text,
cmetadata=metadata,
custom_id=id,
)
session.add(embedding_store)
session.commit()
return ids
def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
"""返回结果中的文档和分数。"""
docs = [
(
Document(
page_content=result.EmbeddingStore.document,
metadata=result.EmbeddingStore.cmetadata,
),
result.distance if self.embedding_function is not None else None,
)
for result in results
]
return docs
[docs] def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
embedding = self.embedding_function.embed_query(text=query)
return self.similarity_search_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
[docs] def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
embedding = self.embedding_function.embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter
)
return docs
[docs] def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
results = self.__query_collection(embedding=embedding, k=k, filter=filter)
return self._results_to_docs_and_scores(results)
def __query_collection(
self,
embedding: List[float],
k: int = 4,
filter: Optional[dict] = None,
) -> List[Any]:
with Session(self._conn) as session:
set_enable_seqscan_stmt = sqlalchemy.text("SET enable_seqscan = off")
set_init_k = sqlalchemy.text("SET hnsw.init_k = :k")
session.execute(set_enable_seqscan_stmt)
session.execute(set_init_k, {"k": k})
filter_by = None
if filter is not None:
filter_clauses = []
for key, value in filter.items():
IN = "in"
if isinstance(value, dict) and IN in map(str.lower, value):
value_case_insensitive = {
k.lower(): v for k, v in value.items()
}
filter_by_metadata = self.EmbeddingStore.cmetadata[
key
].astext.in_(value_case_insensitive[IN])
filter_clauses.append(filter_by_metadata)
else:
filter_by_metadata = self.EmbeddingStore.cmetadata[
key
].astext == str(value)
filter_clauses.append(filter_by_metadata)
filter_by = sqlalchemy.and_(*filter_clauses)
embedding = self._typed_arg_for_distance(embedding)
query = session.query(
self.EmbeddingStore,
getattr(func, self.distance_function)(
self.EmbeddingStore.embedding, embedding
).label("distance"),
) # Specify the columns you need here, e.g., EmbeddingStore.embedding
if filter_by is not None:
query = query.filter(filter_by)
results: List[QueryResult] = (
query.order_by(
self.EmbeddingStore.embedding.op(self._get_operator())(embedding)
) # Using PostgreSQL specific operator with the correct column name
.limit(k)
.all()
)
return results
[docs] def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Document]:
docs_and_scores = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, filter=filter
)
return _results_to_docs(docs_and_scores)
[docs] @classmethod
def from_texts(
cls: Type[Lantern],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> Lantern:
"""从文本列表初始化Lantern向量存储。
将使用提供的`embedding`类生成嵌入。
列表`ids`、`texts`、`metadatas`中元素的顺序应该匹配,
这样每一行将与正确的值关联。
需要Postgres连接字符串
"可以将其作为`connection_string`参数传递
或设置LANTERN_CONNECTION_STRING环境变量。
- `connection_string` 是用于postgres数据库的完全填充的连接字符串
- `texts` 要插入到集合中的文本。
- `embedding` 是将用于嵌入发送的文本的:class:`Embeddings`。
如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。
- `metadatas` 要插入到集合中的行元数据。
- `collection_name` 是要使用的集合的名称。(默认值:langchain)
- 注意:这是嵌入数据将被存储的表的名称
表将在初始化存储时创建(如果不存在)
因此,请确保用户具有创建表的权限。
- `distance_strategy` 是要使用的距离策略。(默认值:EUCLIDEAN)
- `EUCLIDEAN` 是欧几里德距离。
- `COSINE` 是余弦距离。
- `HAMMING` 是汉明距离。
- `ids` 要插入到集合中的行ID。
- `pre_delete_collection` 如果为True,则将删除集合(如果存在)。
(默认值:False)
- 用于测试。
"""
embeddings = embedding.embed_documents(list(texts))
return cls._initialize_from_embeddings(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
pre_delete_collection=pre_delete_collection,
distance_strategy=distance_strategy,
**kwargs,
)
[docs] @classmethod
def from_embeddings(
cls,
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
**kwargs: Any,
) -> Lantern:
"""从原始文档和预生成的嵌入中构建 Lantern 包装器。
需要Postgres连接字符串
"可以将其作为 `connection_string` 参数传递
或设置 LANTERN_CONNECTION_STRING 环境变量。
列表 `ids`、`text_embeddings`、`metadatas` 的元素顺序应该匹配,
这样每一行将与正确的值关联。
- `connection_string` 是用于 postgres 数据库的完全填充的连接字符串
- `text_embeddings` 是包含元组 (text, embedding) 的数组
用于插入到集合中。
- `embedding` 是将用于嵌入发送的文本的 :class:`Embeddings`。
如果未发送任何内容,则将使用多语言 Tensorflow Universal Sentence Encoder。
- `metadatas` 行元数据,用于插入到集合中。
- `collection_name` 是要使用的集合的名称。 (默认值: langchain)
- 注意: 这是嵌入数据将被存储的表的名称
初始化存储时将创建该表(如果不存在)
因此,请确保用户具有创建表的权限。
- `ids` 要插入到集合中的行 ids。
- `pre_delete_collection` 如果为 True,则将删除该集合(如果存在)。
(默认值: False)
- 用于测试。
- `distance_strategy` 是要使用的距离策略。 (默认值: EUCLIDEAN)
- `EUCLIDEAN` 是欧几里德距离。
- `COSINE` 是余弦距禮。
- `HAMMING` 是汉明距离。
"""
texts = [t[0] for t in text_embeddings]
embeddings = [t[1] for t in text_embeddings]
return cls._initialize_from_embeddings(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
pre_delete_collection=pre_delete_collection,
distance_strategy=distance_strategy,
**kwargs,
)
[docs] @classmethod
def from_existing_index(
cls: Type[Lantern],
embedding: Embeddings,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
pre_delete_collection: bool = False,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
**kwargs: Any,
) -> Lantern:
"""获取现有Lantern存储库的实例。此方法将返回存储库的实例,而不会插入任何新的嵌入。
需要Postgres连接字符串
"可以将其作为`connection_string`参数传递
或设置LANTERN_CONNECTION_STRING环境变量。
- `connection_string`是一个Postgres连接字符串。
- `embedding`是将用于嵌入发送的文本的:class:`Embeddings`。如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。
- `collection_name`是要使用的集合的名称。(默认值:langchain)
- 注意:这是嵌入数据将被存储的表的名称
表将在初始化存储时创建(如果不存在)
因此,请确保用户具有创建表的权限。
- `ids`要插入到集合中的行ID。
- `pre_delete_collection`如果为True,则将删除该集合(如果存在)。
(默认值:False)
- 用于测试。
- `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN)
- `EUCLIDEAN`是欧氏距离。
- `COSINE`是余弦距禂。
- `HAMMING`是汉明距离。
"""
connection_string = cls.__get_connection_string(kwargs)
store = cls(
connection_string=connection_string,
collection_name=collection_name,
embedding_function=embedding,
pre_delete_collection=pre_delete_collection,
distance_strategy=distance_strategy,
)
return store
@classmethod
def __get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
connection_string: str = get_from_dict_or_env(
data=kwargs,
key="connection_string",
env_key="LANTERN_CONNECTION_STRING",
)
if not connection_string:
raise ValueError(
"Postgres connection string is required"
"Either pass it as `connection_string` parameter"
"or set the LANTERN_CONNECTION_STRING variable."
)
return connection_string
[docs] @classmethod
def from_documents(
cls: Type[Lantern],
documents: List[Document],
embedding: Embeddings,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> Lantern:
"""初始化一个包含一组文档的向量存储。
需要Postgres连接字符串
"可以通过`connection_string`参数传递
或设置LANTERN_CONNECTION_STRING环境变量。
- `connection_string`是一个Postgres连接字符串。
- `documents`是要初始化向量存储的:class:`Document`列表
- `embedding`是将用于嵌入发送的文本的:class:`Embeddings`。如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。
- `collection_name`是要使用的集合名称。(默认值:langchain)
- 注意:这是将存储嵌入数据的表的名称
初始化存储时将创建该表(如果不存在)
因此,请确保用户具有创建表的权限。
- `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN)
- `EUCLIDEAN`是欧几里得距离。
- `COSINE`是余弦距离。
- `HAMMING`是汉明距离。
- `ids`要插入到集合中的行ID。
- `pre_delete_collection`如果为True,则将删除集合(如果存在)。
(默认值:False)
- 用于测试。
"""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
connection_string = cls.__get_connection_string(kwargs)
kwargs["connection_string"] = connection_string
return cls.from_texts(
texts=texts,
pre_delete_collection=pre_delete_collection,
embedding=embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
distance_strategy=distance_strategy,
**kwargs,
)
[docs] def max_marginal_relevance_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""使用最大边际相关性和分数返回所选文档的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding: 要查找相似文档的嵌入。
k (int): 要返回的文档数量。默认为4。
fetch_k (int): 要获取以传递给MMR算法的文档数量。默认为20。
lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度,其中0对应最大多样性,1对应最小多样性。默认为0.5。
filter (Optional[Dict[str, str]]): 按元数据筛选。默认为None。
返回:
List[Tuple[Document, float]]: 通过最大边际相关性选择的文档列表,以及每个文档的得分。
"""
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
embedding_list = [result.EmbeddingStore.embedding for result in results]
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
embedding_list,
k=k,
lambda_mult=lambda_mult,
)
candidates = self._results_to_docs_and_scores(results)
return [r for i, r in enumerate(candidates) if i in mmr_selected]
[docs] def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query (str): 要查找类似文档的文本。
k (int): 要返回的文档数量。默认为4。
fetch_k (int): 要获取以传递给MMR算法的文档数量。
默认为20。
lambda_mult (float): 介于0和1之间的数字,确定结果之间多样性的程度,
其中0对应于最大多样性,1对应于最小多样性。
默认为0.5。
filter (Optional[Dict[str, str]]): 按元数据筛选。默认为None。
返回:
List[Document]: 通过最大边际相关性选择的文档列表。
"""
embedding = self.embedding_function.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
[docs] def max_marginal_relevance_search_with_score(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[dict] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""返回使用最大边际相关性和分数选择的文档。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
query (str): 要查找类似文档的文本。
k (int): 要返回的文档数量。默认为4。
fetch_k (int): 要获取以传递给MMR算法的文档数量。
默认为20。
lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度,
0对应最大多样性,1对应最小多样性。
默认为0.5。
filter (Optional[Dict[str, str]]): 按元数据过滤。默认为None。
返回:
List[Tuple[Document, float]]: 通过最大边际相关性选择的文档列表,
以及每个文档的得分。
"""
embedding = self.embedding_function.embed_query(query)
docs = self.max_marginal_relevance_search_with_score_by_vector(
embedding=embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
return docs
[docs] def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""返回使用最大边际相关性选择的文档到嵌入向量。
最大边际相关性优化了与查询的相似性和所选文档之间的多样性。
参数:
embedding (str): 要查找类似文档的文本。
k (int): 要返回的文档数量。默认为4。
fetch_k (int): 要获取以传递给MMR算法的文档数量。
默认为20。
lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度,
0对应最大多样性,1对应最小多样性。
默认为0.5。
filter (Optional[Dict[str, str]]): 按元数据筛选。默认为None。
返回:
List[Document]: 通过最大边际相关性选择的文档列表。
"""
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
**kwargs,
)
return _results_to_docs(docs_and_scores)