# Source code for langchain_community.retrievers.vespa_retriever

from __future__ import annotations

import json
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever


class VespaRetriever(BaseRetriever):
    """Retriever that runs queries against a `Vespa` application."""

    # Application object used for querying; `_query` calls `app.query(body)`.
    app: Any
    # Base query body. Each retrieval shallow-copies it and sets "query".
    body: Dict
    # Name of the hit field whose value becomes `Document.page_content`.
    content_field: str
    # Names of hit fields copied into `Document.metadata`, or "*" to copy every
    # returned field except the content field. "*" is listed explicitly because
    # a bare string does not validate as `Sequence[str]`, while `from_params`
    # accepts "*" and `_query` branches on it.
    metadata_fields: Union[Sequence[str], Literal["*"]]

    def _query(self, body: Dict) -> List[Document]:
        """Send ``body`` to the Vespa app and convert the hits into Documents.

        Args:
            body: Query body passed to ``self.app.query``.

        Returns:
            One ``Document`` per hit. ``page_content`` is taken from
            ``content_field`` (``""`` if absent). The metadata holds the
            selected fields plus the hit ``"id"``.

        Raises:
            RuntimeError: If the response status code is not 2xx, or the
                response root contains an ``"errors"`` entry.
        """
        response = self.app.query(body)

        if not str(response.status_code).startswith("2"):
            raise RuntimeError(
                "Could not retrieve data from Vespa. Error code: {}".format(
                    response.status_code
                )
            )

        root = response.json["root"]
        if "errors" in root:
            raise RuntimeError(json.dumps(root["errors"]))

        docs = []
        for child in response.hits:
            # Copy the hit's fields so that popping the content field and
            # adding "id" below do not mutate the response object's data.
            fields = dict(child.get("fields", {}))
            page_content = fields.pop(self.content_field, "")
            if self.metadata_fields == "*":
                # Every returned field except the (already removed) content.
                metadata = fields
            else:
                metadata = {mf: fields.get(mf) for mf in self.metadata_fields}
            metadata["id"] = child["id"]
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve documents for ``query`` using the configured query body."""
        body = self.body.copy()
        body["query"] = query
        return self._query(body)

    def get_relevant_documents_with_filter(
        self, query: str, *, _filter: Optional[str] = None
    ) -> List[Document]:
        """Retrieve documents for ``query`` with an extra YQL filter.

        Args:
            query: User query text, stored under ``"query"`` in the body.
            _filter: Optional YQL condition, appended to the configured
                ``"yql"`` as `` and <_filter>``.

        Returns:
            The matching documents (see ``_query``).

        Raises:
            KeyError: If the configured body has no ``"yql"`` entry.
        """
        body = self.body.copy()
        _filter = f" and {_filter}" if _filter else ""
        # NOTE(review): plain string append; this assumes the configured YQL
        # ends with its WHERE condition (as the YQL built by `from_params`
        # does) — confirm for hand-written YQL.
        body["yql"] = body["yql"] + _filter
        body["query"] = query
        return self._query(body)

    @classmethod
    def from_params(
        cls,
        url: str,
        content_field: str,
        *,
        k: Optional[int] = None,
        metadata_fields: Union[Sequence[str], Literal["*"]] = (),
        sources: Union[Sequence[str], Literal["*"], None] = None,
        _filter: Optional[str] = None,
        yql: Optional[str] = None,
        **kwargs: Any,
    ) -> VespaRetriever:
        """Instantiate a retriever from parameters.

        Args:
            url: URL of the Vespa application.
            content_field: Result field returned as the document page content.
            k: Number of documents to return (sent as ``"hits"``).
                Defaults to None.
            metadata_fields: Result fields to include in document metadata,
                or ``"*"`` for all fields. Defaults to an empty tuple.
            sources: Sources to retrieve from, or ``"*"``/None for all
                sources. Defaults to None.
            _filter: Document filter condition expressed in YQL.
                Defaults to None.
            yql: Full YQL query to use. Must not be given together with
                ``_filter`` or ``sources``. When given, it is used as-is
                instead of building a query. Defaults to None.
            kwargs: Extra keyword arguments added to the query body.

        Returns:
            VespaRetriever: The instantiated retriever.

        Raises:
            ImportError: If ``pyvespa`` is not installed.
            ValueError: If ``yql`` is combined with ``sources`` or ``_filter``.
        """
        try:
            from vespa.application import Vespa
        except ImportError as e:
            raise ImportError(
                "pyvespa is not installed, please install with `pip install pyvespa`"
            ) from e

        if yql and (sources or _filter):
            raise ValueError(
                "yql should only be specified if both sources and _filter are not "
                "specified."
            )

        app = Vespa(url)
        body = kwargs.copy()

        if metadata_fields == "*":
            # All fields are selected; request the "short" document summary.
            body["summary"] = "short"

        # Only build a query when the caller did not provide one. Previously a
        # caller-supplied `yql` was unconditionally overwritten here.
        if not yql:
            if metadata_fields == "*":
                _fields = "*"
            else:
                _fields = ", ".join([content_field] + list(metadata_fields or []))

            # An empty/None/"*" selection means "all sources". A single string
            # is one source name and must not be split into its characters.
            if not sources or sources == "*":
                _sources = "*"
            elif isinstance(sources, str):
                _sources = sources
            else:
                _sources = ", ".join(sources)

            _filter = f" and {_filter}" if _filter else ""
            yql = (
                f"select {_fields} from sources {_sources} "
                f"where userQuery(){_filter}"
            )
        body["yql"] = yql

        if k:
            body["hits"] = k

        return cls(
            app=app,
            body=body,
            content_field=content_field,
            metadata_fields=metadata_fields,
        )