# Source code for langchain_community.retrievers.vespa_retriever
from __future__ import annotations
import json
from typing import Any, Dict, List, Literal, Optional, Sequence, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class VespaRetriever(BaseRetriever):
    """`Vespa` retriever.

    Sends a query body to a Vespa application and converts each returned hit
    into a ``Document``.
    """

    app: Any
    """Vespa application to query."""
    body: Dict
    """The body of the query."""
    content_field: str
    """The name of the content field."""
    metadata_fields: Union[Sequence[str], Literal["*"]]
    """The names of the metadata fields, or ``"*"`` to keep all hit fields."""

    def _query(self, body: Dict) -> List[Document]:
        """Run ``body`` against the Vespa app and turn the hits into documents.

        Args:
            body: Query body passed unchanged to ``self.app.query``.

        Returns:
            One ``Document`` per hit. The page content is the hit's
            ``content_field`` value (empty string when absent); the metadata
            holds the selected fields plus the hit ``id``.

        Raises:
            RuntimeError: If the response status code does not start with
                ``2``, or if the response root contains an ``errors`` entry.
        """
        response = self.app.query(body)

        if not str(response.status_code).startswith("2"):
            raise RuntimeError(
                "Could not retrieve data from Vespa. Error code: {}".format(
                    response.status_code
                )
            )

        root = response.json["root"]
        if "errors" in root:
            raise RuntimeError(json.dumps(root["errors"]))

        docs = []
        for child in response.hits:
            # The content field is popped first, so it never shows up again
            # in the metadata taken from the same ``fields`` mapping.
            page_content = child["fields"].pop(self.content_field, "")
            if self.metadata_fields == "*":
                metadata = child["fields"]
            else:
                metadata = {mf: child["fields"].get(mf) for mf in self.metadata_fields}
            metadata["id"] = child["id"]
            docs.append(Document(page_content=page_content, metadata=metadata))
        return docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents matching ``query``.

        A shallow copy of ``self.body`` is used so the stored body is not
        modified when the ``query`` key is set.
        """
        body = self.body.copy()
        body["query"] = query
        return self._query(body)

    def get_relevant_documents_with_filter(
        self, query: str, *, _filter: Optional[str] = None
    ) -> List[Document]:
        """Return the documents matching ``query``, narrowed by a YQL filter.

        Args:
            query: User query text, sent as the ``query`` key of the body.
            _filter: Optional YQL condition. When given, it is appended to the
                stored ``yql`` statement as `` and <_filter>``.

        Returns:
            The documents produced from the hits of the filtered query.
        """
        body = self.body.copy()
        _filter = f" and {_filter}" if _filter else ""
        body["yql"] = body["yql"] + _filter
        body["query"] = query
        return self._query(body)

    @classmethod
    def from_params(
        cls,
        url: str,
        content_field: str,
        *,
        k: Optional[int] = None,
        metadata_fields: Union[Sequence[str], Literal["*"]] = (),
        sources: Union[Sequence[str], Literal["*"], None] = None,
        _filter: Optional[str] = None,
        yql: Optional[str] = None,
        **kwargs: Any,
    ) -> VespaRetriever:
        """Instantiate a retriever from parameters.

        Args:
            url (str): Vespa app URL.
            content_field (str): Field in the results to return as the
                Document page content.
            k (Optional[int]): Number of Documents to return. Defaults to None.
            metadata_fields (Sequence[str] or "*"): Fields in the results to
                include in the document metadata. Defaults to empty tuple ().
            sources (Sequence[str] or "*" or None): Sources to retrieve
                from. Defaults to None.
            _filter (Optional[str]): Document filter condition expressed in
                YQL. Defaults to None.
            yql (Optional[str]): Full YQL query to be used. Should not be
                specified if _filter or sources are specified.
                Defaults to None.
            kwargs (Any): Keyword arguments added to the query body.

        Returns:
            VespaRetriever: The instantiated VespaRetriever.

        Raises:
            ImportError: If ``pyvespa`` is not installed.
            ValueError: If ``yql`` is given together with ``sources`` or
                ``_filter``.
        """
        try:
            from vespa.application import Vespa
        except ImportError as e:
            raise ImportError(
                "pyvespa is not installed, please install with `pip install pyvespa`"
            ) from e

        app = Vespa(url)
        body = kwargs.copy()

        if yql and (sources or _filter):
            raise ValueError(
                "yql should only be specified if both sources and _filter are not "
                "specified."
            )

        if metadata_fields == "*":
            _fields = "*"
            body["summary"] = "short"
        else:
            _fields = ", ".join([content_field] + list(metadata_fields or []))

        # Only build a statement when the caller did not supply one; an
        # explicit ``yql`` must be used as given instead of being overwritten.
        if not yql:
            if isinstance(sources, str):
                # A bare string is one source name (or "*"); joining it would
                # split it into individual characters.
                _sources = sources or "*"
            elif isinstance(sources, Sequence) and sources:
                _sources = ", ".join(sources)
            else:
                # None or an empty sequence: search every source.
                _sources = "*"
            _filter = f" and {_filter}" if _filter else ""
            yql = f"select {_fields} from sources {_sources} where userQuery(){_filter}"

        body["yql"] = yql
        if k:
            body["hits"] = k
        return cls(
            app=app,
            body=body,
            content_field=content_field,
            metadata_fields=metadata_fields,
        )