Source code for langchain_community.retrievers.docarray

from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever

from langchain_community.vectorstores.utils import maximal_marginal_relevance


[docs]class SearchType(str, Enum): """要执行的搜索类型的枚举器。""" similarity = "similarity" mmr = "mmr"
[docs]class DocArrayRetriever(BaseRetriever): """`DocArray Document Indices` 检索器。 目前,它支持5种后端: InMemoryExactNNIndex、HnswDocumentIndex、QdrantDocumentIndex、 ElasticDocIndex 和 WeaviateDocumentIndex。 参数: index: 上述提到的索引实例之一 embeddings: 用于将文本表示为向量的嵌入模型 search_field: 用于在文档中搜索的字段。 应该是一个嵌入/向量/张量。 content_field: 表示文档模式中主要内容的字段。 将用作 `page_content`。其他所有内容将放入 `metadata`。 search_type: 要执行的搜索类型(相似性 / mmr) filters: 应用于文档检索的过滤器。 top_k: 要返回的文档数量""" index: Any embeddings: Embeddings search_field: str content_field: str search_type: SearchType = SearchType.similarity top_k: int = 1 filters: Optional[Any] = None class Config: """此pydantic对象的配置。""" arbitrary_types_allowed = True def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun, ) -> List[Document]: """获取与查询相关的文档。 参数: query: 要查找相关文档的字符串 返回: 相关文档的列表 """ query_emb = np.array(self.embeddings.embed_query(query)) if self.search_type == SearchType.similarity: results = self._similarity_search(query_emb) elif self.search_type == SearchType.mmr: results = self._mmr_search(query_emb) else: raise ValueError( f"Search type {self.search_type} does not exist. " f"Choose either 'similarity' or 'mmr'." ) return results def _search( self, query_emb: np.ndarray, top_k: int ) -> List[Union[Dict[str, Any], Any]]: """使用查询嵌入执行搜索并返回前k个文档。 参数: query_emb:表示为嵌入的查询 top_k:要返回的文档数 返回: 与查询匹配的前k个文档列表 """ from docarray.index import ElasticDocIndex, WeaviateDocumentIndex filter_args = {} search_field = self.search_field if isinstance(self.index, WeaviateDocumentIndex): filter_args["where_filter"] = self.filters search_field = "" elif isinstance(self.index, ElasticDocIndex): filter_args["query"] = self.filters else: filter_args["filter_query"] = self.filters if self.filters: query = ( self.index.build_query() # get empty query object .find( query=query_emb, search_field=search_field ) # add vector similarity search .filter(**filter_args) # add filter search .build(limit=top_k) # build the query ) # execute the combined query and return the results docs = self.index.execute_query(query) if hasattr(docs, "documents"): docs = docs.documents docs = docs[:top_k] else: docs = self.index.find( query=query_emb, search_field=search_field, limit=top_k ).documents return docs def _similarity_search(self, query_emb: np.ndarray) -> List[Document]: """执行相似性搜索。 参数: query_emb:以嵌入形式表示的查询 返回: 与查询最相似的文档列表 """ docs = self._search(query_emb=query_emb, top_k=self.top_k) results = [self._docarray_to_langchain_doc(doc) for doc in docs] return results def _mmr_search(self, query_emb: np.ndarray) -> List[Document]: """执行最大边际相关性(mmr)搜索。 参数: query_emb:以嵌入表示的查询 返回: 与查询相关的多样性文档列表 """ docs = self._search(query_emb=query_emb, top_k=20) mmr_selected = maximal_marginal_relevance( query_emb, [ doc[self.search_field] if isinstance(doc, dict) else getattr(doc, self.search_field) for doc in docs ], k=self.top_k, ) results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected] return results def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document: """将一个DocArray文档(也可能是一个字典)转换为langchain文档格式。 DocArray文档可以包含任意字段,因此映射是按照以下方式进行的: page_content <-> content_field metadata <-> 所有其他字段,不包括张量和嵌入(即浮点数、整数、字符串) 参数: doc: DocArray文档 返回: 以langchain格式的文档 引发: ValueError: 如果文档不包含内容字段 """ fields = doc.keys() if isinstance(doc, dict) else doc.__fields__ if self.content_field not in fields: raise ValueError( f"Document does not contain the content field - {self.content_field}." ) lc_doc = Document( page_content=doc[self.content_field] if isinstance(doc, dict) else getattr(doc, self.content_field) ) for name in fields: value = doc[name] if isinstance(doc, dict) else getattr(doc, name) if ( isinstance(value, (str, int, float, bool)) and name != self.content_field ): lc_doc.metadata[name] = value return lc_doc