Source code for langchain_community.retrievers.docarray
from enum import Enum
from typing import Any, Dict, List, Optional, Union
import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from langchain_community.vectorstores.utils import maximal_marginal_relevance
[docs]class SearchType(str, Enum):
"""要执行的搜索类型的枚举器。"""
similarity = "similarity"
mmr = "mmr"
[docs]class DocArrayRetriever(BaseRetriever):
"""`DocArray Document Indices` 检索器。
目前,它支持5种后端:
InMemoryExactNNIndex、HnswDocumentIndex、QdrantDocumentIndex、
ElasticDocIndex 和 WeaviateDocumentIndex。
参数:
index: 上述提到的索引实例之一
embeddings: 用于将文本表示为向量的嵌入模型
search_field: 用于在文档中搜索的字段。
应该是一个嵌入/向量/张量。
content_field: 表示文档模式中主要内容的字段。
将用作 `page_content`。其他所有内容将放入 `metadata`。
search_type: 要执行的搜索类型(相似性 / mmr)
filters: 应用于文档检索的过滤器。
top_k: 要返回的文档数量"""
index: Any
embeddings: Embeddings
search_field: str
content_field: str
search_type: SearchType = SearchType.similarity
top_k: int = 1
filters: Optional[Any] = None
class Config:
"""此pydantic对象的配置。"""
arbitrary_types_allowed = True
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""获取与查询相关的文档。
参数:
query: 要查找相关文档的字符串
返回:
相关文档的列表
"""
query_emb = np.array(self.embeddings.embed_query(query))
if self.search_type == SearchType.similarity:
results = self._similarity_search(query_emb)
elif self.search_type == SearchType.mmr:
results = self._mmr_search(query_emb)
else:
raise ValueError(
f"Search type {self.search_type} does not exist. "
f"Choose either 'similarity' or 'mmr'."
)
return results
def _search(
self, query_emb: np.ndarray, top_k: int
) -> List[Union[Dict[str, Any], Any]]:
"""使用查询嵌入执行搜索并返回前k个文档。
参数:
query_emb:表示为嵌入的查询
top_k:要返回的文档数
返回:
与查询匹配的前k个文档列表
"""
from docarray.index import ElasticDocIndex, WeaviateDocumentIndex
filter_args = {}
search_field = self.search_field
if isinstance(self.index, WeaviateDocumentIndex):
filter_args["where_filter"] = self.filters
search_field = ""
elif isinstance(self.index, ElasticDocIndex):
filter_args["query"] = self.filters
else:
filter_args["filter_query"] = self.filters
if self.filters:
query = (
self.index.build_query() # get empty query object
.find(
query=query_emb, search_field=search_field
) # add vector similarity search
.filter(**filter_args) # add filter search
.build(limit=top_k) # build the query
)
# execute the combined query and return the results
docs = self.index.execute_query(query)
if hasattr(docs, "documents"):
docs = docs.documents
docs = docs[:top_k]
else:
docs = self.index.find(
query=query_emb, search_field=search_field, limit=top_k
).documents
return docs
def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
"""执行相似性搜索。
参数:
query_emb:以嵌入形式表示的查询
返回:
与查询最相似的文档列表
"""
docs = self._search(query_emb=query_emb, top_k=self.top_k)
results = [self._docarray_to_langchain_doc(doc) for doc in docs]
return results
def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
"""执行最大边际相关性(mmr)搜索。
参数:
query_emb:以嵌入表示的查询
返回:
与查询相关的多样性文档列表
"""
docs = self._search(query_emb=query_emb, top_k=20)
mmr_selected = maximal_marginal_relevance(
query_emb,
[
doc[self.search_field]
if isinstance(doc, dict)
else getattr(doc, self.search_field)
for doc in docs
],
k=self.top_k,
)
results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
return results
def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
"""将一个DocArray文档(也可能是一个字典)转换为langchain文档格式。
DocArray文档可以包含任意字段,因此映射是按照以下方式进行的:
page_content <-> content_field
metadata <-> 所有其他字段,不包括张量和嵌入(即浮点数、整数、字符串)
参数:
doc: DocArray文档
返回:
以langchain格式的文档
引发:
ValueError: 如果文档不包含内容字段
"""
fields = doc.keys() if isinstance(doc, dict) else doc.__fields__
if self.content_field not in fields:
raise ValueError(
f"Document does not contain the content field - {self.content_field}."
)
lc_doc = Document(
page_content=doc[self.content_field]
if isinstance(doc, dict)
else getattr(doc, self.content_field)
)
for name in fields:
value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
if (
isinstance(value, (str, int, float, bool))
and name != self.content_field
):
lc_doc.metadata[name] = value
return lc_doc