# Source code for langchain.retrievers.multi_vector

from enum import Enum
from typing import Dict, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore

from langchain.storage._lc_store import create_kv_docstore


class SearchType(str, Enum):
    """Kinds of vector-store search a retriever can run."""

    # Plain similarity search.
    similarity = "similarity"
    # Similarity search re-ranked by maximal marginal relevance (MMR).
    mmr = "mmr"
class MultiVectorRetriever(BaseRetriever):
    """Retrieve from a set of multiple embeddings for the same document.

    Small chunks (sub-documents) and their embeddings live in ``vectorstore``;
    each chunk stores the id of its parent document in ``metadata[id_key]``.
    A query is run against the vector store, the parent ids of the matching
    chunks are collected (first-seen order, duplicates dropped) and the parent
    documents are then fetched from ``docstore``.
    """

    vectorstore: VectorStore
    """The underlying vector store holding the small chunks and their embedding vectors."""
    byte_store: Optional[ByteStore] = None
    """The lower-level backing storage layer for the parent documents."""
    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents."""
    id_key: str = "doc_id"
    """Metadata key on each sub-document that holds its parent document's id."""
    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""
    search_type: SearchType = SearchType.similarity
    """Type of search to perform (similarity / mmr)."""

    @root_validator(pre=True)
    def shim_docstore(cls, values: Dict) -> Dict:
        """Populate ``docstore`` from ``byte_store`` when one is given.

        A provided ``byte_store`` takes precedence: it is wrapped with
        ``create_kv_docstore`` and the result replaces any ``docstore`` value.
        Otherwise an explicit ``docstore`` must be supplied.

        Raises:
            Exception: if neither ``byte_store`` nor ``docstore`` is provided.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            # Either store is acceptable, so the message names both. The plain
            # ``Exception`` type is kept so existing callers catch the same thing.
            raise Exception("You must pass a `byte_store` or `docstore` parameter.")
        values["docstore"] = docstore
        return values

    def _parent_ids(self, sub_docs: List[Document]) -> List[str]:
        """Return the parent ids of ``sub_docs`` in first-seen order, deduplicated.

        Sub-documents whose metadata lacks ``id_key`` are skipped. Order is
        preserved so parents come back ranked by their best-matching chunk.
        """
        seen = set()
        ids: List[str] = []
        for d in sub_docs:
            if self.id_key not in d.metadata:
                continue
            doc_id = d.metadata[self.id_key]
            # A set makes the duplicate check O(1) instead of scanning ``ids``.
            if doc_id not in seen:
                seen.add(doc_id)
                ids.append(doc_id)
        return ids

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use (not used by this
                implementation).

        Returns:
            The parent documents of the matching sub-documents, in the order
            their first matching sub-document was returned. ``None`` entries
            from ``docstore.mget`` (presumably ids absent from the store) are
            dropped.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
        docs = self.docstore.mget(self._parent_ids(sub_docs))
        return [d for d in docs if d is not None]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use (not used by this
                implementation).

        Returns:
            The parent documents of the matching sub-documents, in the order
            their first matching sub-document was returned. ``None`` entries
            from ``docstore.amget`` (presumably ids absent from the store) are
            dropped.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query, **self.search_kwargs
            )
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query, **self.search_kwargs
            )
        docs = await self.docstore.amget(self._parent_ids(sub_docs))
        return [d for d in docs if d is not None]