Source code for langchain_community.vectorstores.lantern

from __future__ import annotations

import contextlib
import enum
import logging
import uuid
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
import sqlalchemy
from sqlalchemy import delete, func
from sqlalchemy.dialects.postgresql import JSON, UUID
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name

from langchain_community.vectorstores.utils import maximal_marginal_relevance

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

ADA_TOKEN_COUNT = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"


def _results_to_docs(docs_and_scores: Any) -> List[Document]:
    """从文档和分数中返回文档。"""
    return [doc for doc, _ in docs_and_scores]


[docs]class BaseEmbeddingStore: """用于灯笼嵌入存储的基类。"""
[docs]def get_embedding_store( distance_strategy: DistanceStrategy, collection_name: str ) -> Any: """获取嵌入式存储类。""" embedding_type = None if distance_strategy == DistanceStrategy.HAMMING: embedding_type = sqlalchemy.INTEGER # type: ignore else: embedding_type = sqlalchemy.REAL # type: ignore DynamicBase = declarative_base(class_registry=dict()) # type: Any class EmbeddingStore(DynamicBase, BaseEmbeddingStore): __tablename__ = collection_name uuid = sqlalchemy.Column( UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 ) __table_args__ = {"extend_existing": True} document = sqlalchemy.Column(sqlalchemy.String, nullable=True) cmetadata = sqlalchemy.Column(JSON, nullable=True) # custom_id : any user defined id custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) embedding = sqlalchemy.Column(sqlalchemy.ARRAY(embedding_type)) # type: ignore return EmbeddingStore
[docs]class QueryResult: """查询结果。""" EmbeddingStore: BaseEmbeddingStore distance: float
[docs]class DistanceStrategy(str, enum.Enum): """距离策略的枚举器。""" EUCLIDEAN = "l2sq" COSINE = "cosine" HAMMING = "hamming"
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
[docs]class Lantern(VectorStore): """`Postgres`与`lantern`扩展作为向量存储。 lantern默认使用顺序扫描。但是你可以使用create_hnsw_index方法创建HNSW索引。 - `connection_string`是一个postgres连接字符串。 - `embedding_function`是实现`langchain.embeddings.base.Embeddings`接口的任何嵌入函数。 - `collection_name`是要使用的集合的名称。(默认值:langchain) - 注意:这是嵌入数据将被存储的表的名称 表将在初始化存储时创建(如果不存在) 因此,请确保用户有创建表的权限。 - `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN) - `EUCLIDEAN`是欧氏距离。 - `COSINE`是余弦距离。 - `HAMMING`是汉明距离。 - `pre_delete_collection`如果为True,将删除集合(如果存在)。 (默认值:False) - 用于测试。"""
[docs] def __init__( self, connection_string: str, embedding_function: Embeddings, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, pre_delete_collection: bool = False, logger: Optional[logging.Logger] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, ) -> None: self.connection_string = connection_string self.embedding_function = embedding_function self.collection_name = collection_name self.collection_metadata = collection_metadata self._distance_strategy = distance_strategy self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn self.EmbeddingStore = get_embedding_store( self.distance_strategy, collection_name ) self.__post_init__()
def __post_init__( self, ) -> None: self._conn = self.connect() self.create_hnsw_extension() self.create_collection() @property def distance_strategy(self) -> DistanceStrategy: if isinstance(self._distance_strategy, DistanceStrategy): return self._distance_strategy if self._distance_strategy == DistanceStrategy.EUCLIDEAN.value: return DistanceStrategy.EUCLIDEAN elif self._distance_strategy == DistanceStrategy.COSINE.value: return DistanceStrategy.COSINE elif self._distance_strategy == DistanceStrategy.HAMMING.value: return DistanceStrategy.HAMMING else: raise ValueError( f"Got unexpected value for distance: {self._distance_strategy}. " f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}." ) @property def embeddings(self) -> Embeddings: return self.embedding_function
[docs] @classmethod def connection_string_from_db_params( cls, driver: str, host: str, port: int, database: str, user: str, password: str, ) -> str: """从数据库参数返回连接字符串。""" return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
[docs] def connect(self) -> sqlalchemy.engine.Connection: engine = sqlalchemy.create_engine(self.connection_string) conn = engine.connect() return conn
@property def distance_function(self) -> Any: if self.distance_strategy == DistanceStrategy.EUCLIDEAN: return "l2sq_dist" elif self.distance_strategy == DistanceStrategy.COSINE: return "cos_dist" elif self.distance_strategy == DistanceStrategy.HAMMING: return "hamming_dist"
[docs] def create_hnsw_extension(self) -> None: try: with Session(self._conn) as session: statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS lantern") session.execute(statement) session.commit() except Exception as e: self.logger.exception(e)
[docs] def create_tables_if_not_exists(self) -> None: try: self.create_collection() except ProgrammingError: pass
[docs] def drop_table(self) -> None: try: self.EmbeddingStore.__table__.drop(self._conn.engine) except ProgrammingError: pass
[docs] def drop_tables(self) -> None: self.drop_table()
def _hamming_relevance_score_fn(self, distance: float) -> float: return distance def _select_relevance_score_fn(self) -> Callable[[float], float]: """“正确”的相关性函数可能会有所不同,取决于一些因素,包括: - 向量存储中使用的距离/相似度度量 - 嵌入的规模(OpenAI的是单位规范化的。许多其他嵌入不是!) - 嵌入的维度 - 等等。 """ if self.override_relevance_score_fn is not None: return self.override_relevance_score_fn # Default strategy is to rely on distance strategy provided # in vectorstore constructor if self.distance_strategy == DistanceStrategy.COSINE: return self._cosine_relevance_score_fn elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn elif self.distance_strategy == DistanceStrategy.HAMMING: return self._hamming_relevance_score_fn else: raise ValueError( "No supported normalization function" f" for distance_strategy of {self._distance_strategy}." "Consider providing relevance_score_fn to Lantern constructor." ) def _get_op_class(self) -> str: if self.distance_strategy == DistanceStrategy.COSINE: return "dist_cos_ops" elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: return "dist_l2sq_ops" elif self.distance_strategy == DistanceStrategy.HAMMING: return "dist_hamming_ops" else: raise ValueError( "No supported operator class" f" for distance_strategy of {self._distance_strategy}." ) def _get_operator(self) -> str: if self.distance_strategy == DistanceStrategy.COSINE: return "<=>" elif self.distance_strategy == DistanceStrategy.EUCLIDEAN: return "<->" elif self.distance_strategy == DistanceStrategy.HAMMING: return "<+>" else: raise ValueError( "No supported operator" f" for distance_strategy of {self._distance_strategy}." ) def _typed_arg_for_distance( self, embedding: List[Union[float, int]] ) -> List[Union[float, int]]: if self.distance_strategy == DistanceStrategy.HAMMING: return list(map(lambda x: int(x), embedding)) return embedding @property def _index_name(self) -> str: return f"langchain_{self.collection_name}_idx"
[docs] def create_hnsw_index( self, dims: int = ADA_TOKEN_COUNT, m: int = 16, ef_construction: int = 64, ef_search: int = 64, **_kwargs: Any, ) -> None: """在集合上创建HNSW索引。 HNSW索引的可选关键字参数: engine: "nmslib", "faiss", "lucene"; 默认值: "nmslib" ef: k-NN搜索期间使用的动态列表的大小。较高的值会导致更准确但更慢的搜索;默认值: 64 ef_construction: k-NN图创建期间使用的动态列表的大小。较高的值会导致更准确的图形但索引速度较慢;默认值: 64 m: 每个新元素创建的双向链接数。对内存消耗有很大影响。介于2和100之间;默认值: 16 dims: 集合中向量的维度。默认值: 1536 """ create_index_query = sqlalchemy.text( "CREATE INDEX IF NOT EXISTS {} " "ON {} USING hnsw (embedding {}) " "WITH (" "dim = :dim, " "m = :m, " "ef_construction = :ef_construction, " "ef = :ef" ");".format( quoted_name(self._index_name, True), quoted_name(self.collection_name, True), self._get_op_class(), ) ) with Session(self._conn) as session: # Create the HNSW index session.execute( create_index_query, { "dim": dims, "m": m, "ef_construction": ef_construction, "ef": ef_search, }, ) session.commit() self.logger.info("HNSW extension and index created successfully.")
[docs] def drop_index(self) -> None: with Session(self._conn) as session: # Drop the HNSW index session.execute( sqlalchemy.text( "DROP INDEX IF EXISTS {}".format( quoted_name(self._index_name, True) ) ) ) session.commit()
[docs] def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() self.drop_table() with self._conn.begin(): try: self.EmbeddingStore.__table__.create(self._conn.engine) except ProgrammingError as e: # Duplicate table if e.code == "f405": pass else: raise e
[docs] def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") self.drop_table()
@contextlib.contextmanager def _make_session(self) -> Generator[Session, None, None]: """为会话创建一个上下文管理器,绑定到_conn字符串。""" yield Session(self._conn)
[docs] def delete( self, ids: Optional[List[str]] = None, **kwargs: Any, ) -> None: """根据id或uuid删除向量。 参数: ids:要删除的id列表。 """ with Session(self._conn) as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " "using the custom ids field)" ) stmt = delete(self.EmbeddingStore).where( self.EmbeddingStore.custom_id.in_(ids) ) session.execute(stmt) session.commit()
@classmethod def _initialize_from_embeddings( cls, texts: List[str], embeddings: List[List[float]], embedding: Embeddings, metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, **kwargs: Any, ) -> Lantern: """列表`ids`,`embeddings`,`texts`,`metadatas`的元素顺序应该匹配,这样每一行将与正确的值关联。 需要Postgres连接字符串 "可以将其作为`connection_string`参数传递 或设置LANTERN_CONNECTION_STRING环境变量。 - `texts` 要插入到集合中的文本。 - `embeddings` 要插入到集合中的嵌入。 - `embedding` 是要用于嵌入发送的文本的:class:`Embeddings`。 如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。 - `metadatas` 要插入到集合中的行元数据。 - `ids` 要插入到集合中的行ID。 - `collection_name` 是要使用的集合的名称(默认值:langchain)。 - 注意:这是嵌入数据将被存储的表的名称 表将在初始化存储时创建(如果不存在) 因此,请确保用户具有创建表的权限。 - `distance_strategy` 是要使用的距离策略(默认值:EUCLIDEAN)。 - `EUCLIDEAN` 是欧几里德距离。 - `COSINE` 是余弦距禂。 - `HAMMING` 是汉明距离。 - `pre_delete_collection` 如果为True,则如果存在,将删除集合。 (默认值:False) - 用于测试。 """ if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] connection_string = cls.__get_connection_string(kwargs) store = cls( connection_string=connection_string, collection_name=collection_name, embedding_function=embedding, pre_delete_collection=pre_delete_collection, distance_strategy=distance_strategy, ) store.add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) store.create_hnsw_index(**kwargs) return store
[docs] def add_embeddings( self, texts: List[str], embeddings: List[List[float]], metadatas: List[dict], ids: List[str], **kwargs: Any, ) -> None: with Session(self._conn) as session: for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): embedding_store = self.EmbeddingStore( embedding=embedding, document=text, cmetadata=metadata, custom_id=id, ) session.add(embedding_store) session.commit()
[docs] def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, **kwargs: Any, ) -> List[str]: if ids is None: ids = [str(uuid.uuid4()) for _ in texts] embeddings = self.embedding_function.embed_documents(list(texts)) if not metadatas: metadatas = [{} for _ in texts] with Session(self._conn) as session: for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): embedding_store = self.EmbeddingStore( embedding=embedding, document=text, cmetadata=metadata, custom_id=id, ) session.add(embedding_store) session.commit() return ids
def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: """返回结果中的文档和分数。""" docs = [ ( Document( page_content=result.EmbeddingStore.document, metadata=result.EmbeddingStore.cmetadata, ), result.distance if self.embedding_function is not None else None, ) for result in results ] return docs
[docs] def similarity_search_with_score( self, query: str, k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) return docs
[docs] def similarity_search_with_score_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results)
def __query_collection( self, embedding: List[float], k: int = 4, filter: Optional[dict] = None, ) -> List[Any]: with Session(self._conn) as session: set_enable_seqscan_stmt = sqlalchemy.text("SET enable_seqscan = off") set_init_k = sqlalchemy.text("SET hnsw.init_k = :k") session.execute(set_enable_seqscan_stmt) session.execute(set_init_k, {"k": k}) filter_by = None if filter is not None: filter_clauses = [] for key, value in filter.items(): IN = "in" if isinstance(value, dict) and IN in map(str.lower, value): value_case_insensitive = { k.lower(): v for k, v in value.items() } filter_by_metadata = self.EmbeddingStore.cmetadata[ key ].astext.in_(value_case_insensitive[IN]) filter_clauses.append(filter_by_metadata) else: filter_by_metadata = self.EmbeddingStore.cmetadata[ key ].astext == str(value) filter_clauses.append(filter_by_metadata) filter_by = sqlalchemy.and_(*filter_clauses) embedding = self._typed_arg_for_distance(embedding) query = session.query( self.EmbeddingStore, getattr(func, self.distance_function)( self.EmbeddingStore.embedding, embedding ).label("distance"), ) # Specify the columns you need here, e.g., EmbeddingStore.embedding if filter_by is not None: query = query.filter(filter_by) results: List[QueryResult] = ( query.order_by( self.EmbeddingStore.embedding.op(self._get_operator())(embedding) ) # Using PostgreSQL specific operator with the correct column name .limit(k) .all() ) return results
[docs] def similarity_search_by_vector( self, embedding: List[float], k: int = 4, filter: Optional[dict] = None, **kwargs: Any, ) -> List[Document]: docs_and_scores = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) return _results_to_docs(docs_and_scores)
[docs] @classmethod def from_texts( cls: Type[Lantern], texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, **kwargs: Any, ) -> Lantern: """从文本列表初始化Lantern向量存储。 将使用提供的`embedding`类生成嵌入。 列表`ids`、`texts`、`metadatas`中元素的顺序应该匹配, 这样每一行将与正确的值关联。 需要Postgres连接字符串 "可以将其作为`connection_string`参数传递 或设置LANTERN_CONNECTION_STRING环境变量。 - `connection_string` 是用于postgres数据库的完全填充的连接字符串 - `texts` 要插入到集合中的文本。 - `embedding` 是将用于嵌入发送的文本的:class:`Embeddings`。 如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。 - `metadatas` 要插入到集合中的行元数据。 - `collection_name` 是要使用的集合的名称。(默认值:langchain) - 注意:这是嵌入数据将被存储的表的名称 表将在初始化存储时创建(如果不存在) 因此,请确保用户具有创建表的权限。 - `distance_strategy` 是要使用的距离策略。(默认值:EUCLIDEAN) - `EUCLIDEAN` 是欧几里德距离。 - `COSINE` 是余弦距离。 - `HAMMING` 是汉明距离。 - `ids` 要插入到集合中的行ID。 - `pre_delete_collection` 如果为True,则将删除集合(如果存在)。 (默认值:False) - 用于测试。 """ embeddings = embedding.embed_documents(list(texts)) return cls._initialize_from_embeddings( texts, embeddings, embedding, metadatas=metadatas, ids=ids, collection_name=collection_name, pre_delete_collection=pre_delete_collection, distance_strategy=distance_strategy, **kwargs, )
[docs] @classmethod def from_embeddings( cls, text_embeddings: List[Tuple[str, List[float]]], embedding: Embeddings, metadatas: Optional[List[dict]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, **kwargs: Any, ) -> Lantern: """从原始文档和预生成的嵌入中构建 Lantern 包装器。 需要Postgres连接字符串 "可以将其作为 `connection_string` 参数传递 或设置 LANTERN_CONNECTION_STRING 环境变量。 列表 `ids`、`text_embeddings`、`metadatas` 的元素顺序应该匹配, 这样每一行将与正确的值关联。 - `connection_string` 是用于 postgres 数据库的完全填充的连接字符串 - `text_embeddings` 是包含元组 (text, embedding) 的数组 用于插入到集合中。 - `embedding` 是将用于嵌入发送的文本的 :class:`Embeddings`。 如果未发送任何内容,则将使用多语言 Tensorflow Universal Sentence Encoder。 - `metadatas` 行元数据,用于插入到集合中。 - `collection_name` 是要使用的集合的名称。 (默认值: langchain) - 注意: 这是嵌入数据将被存储的表的名称 初始化存储时将创建该表(如果不存在) 因此,请确保用户具有创建表的权限。 - `ids` 要插入到集合中的行 ids。 - `pre_delete_collection` 如果为 True,则将删除该集合(如果存在)。 (默认值: False) - 用于测试。 - `distance_strategy` 是要使用的距离策略。 (默认值: EUCLIDEAN) - `EUCLIDEAN` 是欧几里德距离。 - `COSINE` 是余弦距禮。 - `HAMMING` 是汉明距离。 """ texts = [t[0] for t in text_embeddings] embeddings = [t[1] for t in text_embeddings] return cls._initialize_from_embeddings( texts, embeddings, embedding, metadatas=metadatas, ids=ids, collection_name=collection_name, pre_delete_collection=pre_delete_collection, distance_strategy=distance_strategy, **kwargs, )
[docs] @classmethod def from_existing_index( cls: Type[Lantern], embedding: Embeddings, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, pre_delete_collection: bool = False, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, **kwargs: Any, ) -> Lantern: """获取现有Lantern存储库的实例。此方法将返回存储库的实例,而不会插入任何新的嵌入。 需要Postgres连接字符串 "可以将其作为`connection_string`参数传递 或设置LANTERN_CONNECTION_STRING环境变量。 - `connection_string`是一个Postgres连接字符串。 - `embedding`是将用于嵌入发送的文本的:class:`Embeddings`。如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。 - `collection_name`是要使用的集合的名称。(默认值:langchain) - 注意:这是嵌入数据将被存储的表的名称 表将在初始化存储时创建(如果不存在) 因此,请确保用户具有创建表的权限。 - `ids`要插入到集合中的行ID。 - `pre_delete_collection`如果为True,则将删除该集合(如果存在)。 (默认值:False) - 用于测试。 - `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN) - `EUCLIDEAN`是欧氏距离。 - `COSINE`是余弦距禂。 - `HAMMING`是汉明距离。 """ connection_string = cls.__get_connection_string(kwargs) store = cls( connection_string=connection_string, collection_name=collection_name, embedding_function=embedding, pre_delete_collection=pre_delete_collection, distance_strategy=distance_strategy, ) return store
@classmethod def __get_connection_string(cls, kwargs: Dict[str, Any]) -> str: connection_string: str = get_from_dict_or_env( data=kwargs, key="connection_string", env_key="LANTERN_CONNECTION_STRING", ) if not connection_string: raise ValueError( "Postgres connection string is required" "Either pass it as `connection_string` parameter" "or set the LANTERN_CONNECTION_STRING variable." ) return connection_string
[docs] @classmethod def from_documents( cls: Type[Lantern], documents: List[Document], embedding: Embeddings, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, **kwargs: Any, ) -> Lantern: """初始化一个包含一组文档的向量存储。 需要Postgres连接字符串 "可以通过`connection_string`参数传递 或设置LANTERN_CONNECTION_STRING环境变量。 - `connection_string`是一个Postgres连接字符串。 - `documents`是要初始化向量存储的:class:`Document`列表 - `embedding`是将用于嵌入发送的文本的:class:`Embeddings`。如果未发送任何内容,则将使用多语言Tensorflow通用句子编码器。 - `collection_name`是要使用的集合名称。(默认值:langchain) - 注意:这是将存储嵌入数据的表的名称 初始化存储时将创建该表(如果不存在) 因此,请确保用户具有创建表的权限。 - `distance_strategy`是要使用的距离策略。(默认值:EUCLIDEAN) - `EUCLIDEAN`是欧几里得距离。 - `COSINE`是余弦距离。 - `HAMMING`是汉明距离。 - `ids`要插入到集合中的行ID。 - `pre_delete_collection`如果为True,则将删除集合(如果存在)。 (默认值:False) - 用于测试。 """ texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] connection_string = cls.__get_connection_string(kwargs) kwargs["connection_string"] = connection_string return cls.from_texts( texts=texts, pre_delete_collection=pre_delete_collection, embedding=embedding, metadatas=metadatas, ids=ids, collection_name=collection_name, distance_strategy=distance_strategy, **kwargs, )
[docs] def max_marginal_relevance_search_with_score_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """使用最大边际相关性和分数返回所选文档的文档。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: embedding: 要查找相似文档的嵌入。 k (int): 要返回的文档数量。默认为4。 fetch_k (int): 要获取以传递给MMR算法的文档数量。默认为20。 lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度,其中0对应最大多样性,1对应最小多样性。默认为0.5。 filter (Optional[Dict[str, str]]): 按元数据筛选。默认为None。 返回: List[Tuple[Document, float]]: 通过最大边际相关性选择的文档列表,以及每个文档的得分。 """ results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] mmr_selected = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), embedding_list, k=k, lambda_mult=lambda_mult, ) candidates = self._results_to_docs_and_scores(results) return [r for i, r in enumerate(candidates) if i in mmr_selected]
[docs] def max_marginal_relevance_search_with_score( self, query: str, k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[dict] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """返回使用最大边际相关性和分数选择的文档。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: query (str): 要查找类似文档的文本。 k (int): 要返回的文档数量。默认为4。 fetch_k (int): 要获取以传递给MMR算法的文档数量。 默认为20。 lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度, 0对应最大多样性,1对应最小多样性。 默认为0.5。 filter (Optional[Dict[str, str]]): 按元数据过滤。默认为None。 返回: List[Tuple[Document, float]]: 通过最大边际相关性选择的文档列表, 以及每个文档的得分。 """ embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter, **kwargs, ) return docs
[docs] def max_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: """返回使用最大边际相关性选择的文档到嵌入向量。 最大边际相关性优化了与查询的相似性和所选文档之间的多样性。 参数: embedding (str): 要查找类似文档的文本。 k (int): 要返回的文档数量。默认为4。 fetch_k (int): 要获取以传递给MMR算法的文档数量。 默认为20。 lambda_mult (float): 0到1之间的数字,确定结果之间多样性的程度, 0对应最大多样性,1对应最小多样性。 默认为0.5。 filter (Optional[Dict[str, str]]): 按元数据筛选。默认为None。 返回: List[Document]: 通过最大边际相关性选择的文档列表。 """ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter, **kwargs, ) return _results_to_docs(docs_and_scores)