class QueryEngineTool(AsyncBaseTool):
    """Query engine tool.

    A tool that makes use of a query engine.

    Args:
        query_engine (BaseQueryEngine): A query engine.
        metadata (ToolMetadata): The metadata associated with the query engine.
        resolve_input_errors (bool): When True (the default), a call that
            provides neither a positional argument nor an ``input`` keyword
            argument is still executed, using ``str(kwargs)`` as the query.
            When False, such a call raises ``ValueError``.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        metadata: ToolMetadata,
        resolve_input_errors: bool = True,
    ) -> None:
        self._query_engine = query_engine
        self._metadata = metadata
        self._resolve_input_errors = resolve_input_errors

    @classmethod
    def from_defaults(
        cls,
        query_engine: BaseQueryEngine,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        resolve_input_errors: bool = True,
    ) -> "QueryEngineTool":
        """Build a tool around ``query_engine``, filling in default metadata.

        Args:
            query_engine: The query engine to wrap.
            name: Tool name. Any falsy value (``None`` or ``""``) falls back
                to ``DEFAULT_NAME``.
            description: Tool description. Any falsy value falls back to
                ``DEFAULT_DESCRIPTION``.
            return_direct: Stored on the tool's ``ToolMetadata``.
            resolve_input_errors: Forwarded to ``__init__``.

        Returns:
            A new ``QueryEngineTool`` (or subclass instance, via ``cls``).
        """
        metadata = ToolMetadata(
            name=name or DEFAULT_NAME,
            description=description or DEFAULT_DESCRIPTION,
            return_direct=return_direct,
        )
        return cls(
            query_engine=query_engine,
            metadata=metadata,
            resolve_input_errors=resolve_input_errors,
        )

    @property
    def query_engine(self) -> BaseQueryEngine:
        """The wrapped query engine."""
        return self._query_engine

    @property
    def metadata(self) -> ToolMetadata:
        """The metadata this tool was constructed with."""
        return self._metadata

    def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Run the query engine synchronously via its ``query`` method.

        The query is resolved from the arguments by ``_get_query_str``.

        Returns:
            A ``ToolOutput`` whose ``content`` is ``str(response)`` and whose
            ``raw_output`` is the engine's response object.

        Raises:
            ValueError: If no query can be resolved from the arguments and
                ``resolve_input_errors`` is False.
        """
        query_str = self._get_query_str(*args, **kwargs)
        response = self._query_engine.query(query_str)
        return self._to_tool_output(query_str, response)

    async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
        """Run the query engine asynchronously by awaiting its ``aquery`` method.

        The query is resolved from the arguments by ``_get_query_str``.

        Returns:
            A ``ToolOutput`` whose ``content`` is ``str(response)`` and whose
            ``raw_output`` is the engine's response object.

        Raises:
            ValueError: If no query can be resolved from the arguments and
                ``resolve_input_errors`` is False.
        """
        query_str = self._get_query_str(*args, **kwargs)
        response = await self._query_engine.aquery(query_str)
        return self._to_tool_output(query_str, response)

    def as_langchain_tool(self) -> "LlamaIndexTool":
        """Return a ``LlamaIndexTool`` built from this tool's engine, name and description."""
        # Imported inside the method, presumably so the langchain helper
        # module is only loaded when this conversion is actually requested.
        from llama_index.core.langchain_helpers.agents.tools import (
            IndexToolConfig,
            LlamaIndexTool,
        )

        tool_config = IndexToolConfig(
            query_engine=self.query_engine,
            name=self.metadata.name,
            description=self.metadata.description,
        )
        return LlamaIndexTool.from_tool_config(tool_config=tool_config)

    def _to_tool_output(self, query_str: str, response: Any) -> ToolOutput:
        """Package a query engine response as a ``ToolOutput``.

        Shared by ``call`` and ``acall`` so both return identically shaped output.
        """
        return ToolOutput(
            content=str(response),
            tool_name=self.metadata.name,
            raw_input={"input": query_str},
            raw_output=response,
        )

    def _get_query_str(self, *args: Any, **kwargs: Any) -> str:
        """Resolve the query to send to the engine from the call arguments.

        Resolution order:
          1. The first positional argument, converted with ``str()``.
          2. The ``input`` keyword argument.
          3. If ``resolve_input_errors`` is True, ``str(kwargs)``: the whole
             keyword-argument dict rendered as text (``"{}"`` when no
             keyword arguments were given).

        Raises:
            ValueError: If none of the above applies (only possible when
                ``resolve_input_errors`` is False).
        """
        # ``*args`` is always a tuple and ``**kwargs`` always a dict, so
        # neither can be None; truthiness / membership tests are sufficient.
        if args:
            return str(args[0])
        if "input" in kwargs:
            # NOTE: this assumes our default function schema of `input`
            # Passed through as-is (not coerced with str(), unlike the
            # positional case above).
            return kwargs["input"]
        if self._resolve_input_errors:
            return str(kwargs)
        raise ValueError(
            "Cannot call query engine without specifying `input` parameter."
        )