"""SQL agent."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
cast,
)
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_community.agent_toolkits.sql.prompt import (
SQL_FUNCTIONS_SUFFIX,
SQL_PREFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
)
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_community.utilities.sql_database import SQLDatabase
def _default_sql_chat_prompt(prefix: str, suffix: Optional[str]) -> ChatPromptTemplate:
    """Build the default chat prompt shared by the function/tool-calling agents.

    Args:
        prefix: Already-rendered system message text (used verbatim).
        suffix: Text of the priming AI message; falls back to
            ``SQL_FUNCTIONS_SUFFIX`` when empty or ``None``.

    Returns:
        A chat prompt made of the system message, the ``{input}`` human
        message, the AI suffix message and the ``agent_scratchpad`` placeholder.
    """
    return ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=prefix),
            HumanMessagePromptTemplate.from_template("{input}"),
            AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )


def create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: Optional[SQLDatabaseToolkit] = None,
    agent_type: Optional[
        Union[AgentType, Literal["openai-tools", "tool-calling"]]
    ] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    extra_tools: Sequence[BaseTool] = (),
    *,
    db: Optional[SQLDatabase] = None,
    prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a SQL agent from an LLM and a toolkit or database.

    Args:
        llm: Language model to use for the agent. If agent_type is
            "tool-calling" then llm is expected to support tool calling.
        toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly
            one of 'toolkit' or 'db'. Specify 'toolkit' if you want to use a
            different model for the agent and the toolkit.
        agent_type: One of "tool-calling", "openai-tools", "openai-functions",
            or "zero-shot-react-description". Defaults to
            "zero-shot-react-description". "tool-calling" is recommended over
            the legacy "openai-tools" and "openai-functions" types.
        callback_manager: DEPRECATED. Pass a "callbacks" key into
            'agent_executor_kwargs' instead to pass constructor callbacks to
            AgentExecutor.
        prefix: Prompt prefix string. It is rendered with ``str.format`` using
            the variables "top_k" and "dialect". Defaults to ``SQL_PREFIX``.
            Only used when 'prompt' is not given.
        suffix: Prompt suffix string. Default depends on agent type: the
            generic ReAct suffix for "zero-shot-react-description",
            ``SQL_FUNCTIONS_SUFFIX`` otherwise. Only used when 'prompt' is not
            given.
        format_instructions: Formatting instructions inserted into the default
            ReAct prompt when 'agent_type' is "zero-shot-react-description"
            and 'prompt' is not given. Otherwise ignored.
        input_variables: DEPRECATED. Not used.
        top_k: Number of rows to query for by default.
        max_iterations: Passed to AgentExecutor init.
        max_execution_time: Passed to AgentExecutor init.
        early_stopping_method: Passed to AgentExecutor init.
        verbose: AgentExecutor verbosity.
        agent_executor_kwargs: Arbitrary additional AgentExecutor args.
        extra_tools: Additional tools to give to the agent on top of the ones
            that come with SQLDatabaseToolkit.
        db: SQLDatabase from which to create a SQLDatabaseToolkit. The toolkit
            is created using 'db' and 'llm'. Must provide exactly one of 'db'
            or 'toolkit'.
        prompt: Complete agent prompt. prompt and {prefix, suffix,
            format_instructions, input_variables} are mutually exclusive. If
            the prompt declares "top_k", "dialect", "table_info" or
            "table_names" input variables they are pre-filled; when
            "table_info" / "table_names" are pre-filled the corresponding
            info / list tools are removed from the agent's tool set.
        **kwargs: Arbitrary additional Agent args.

    Returns:
        An AgentExecutor with the specified agent_type agent.

    Raises:
        ValueError: If neither or both of 'toolkit' and 'db' are provided, or
            if 'agent_type' is not one of the supported types.

    Example:

        .. code-block:: python

            from langchain_openai import ChatOpenAI
            from langchain_community.agent_toolkits import create_sql_agent
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            agent_executor = create_sql_agent(llm, db=db, agent_type="tool-calling", verbose=True)

    """  # noqa: E501
    # Imported lazily, as in the rest of this module, rather than at file level.
    from langchain.agents import (
        create_openai_functions_agent,
        create_openai_tools_agent,
        create_react_agent,
        create_tool_calling_agent,
    )
    from langchain.agents.agent import (
        AgentExecutor,
        RunnableAgent,
        RunnableMultiActionAgent,
    )
    from langchain.agents.agent_types import AgentType

    if toolkit is None and db is None:
        raise ValueError(
            "Must provide exactly one of 'toolkit' or 'db'. Received neither."
        )
    if toolkit is not None and db is not None:
        raise ValueError(
            "Must provide exactly one of 'toolkit' or 'db'. Received both."
        )

    if toolkit is None:
        toolkit = SQLDatabaseToolkit(llm=llm, db=db)  # type: ignore[arg-type]
    agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
    tools = toolkit.get_tools() + list(extra_tools)

    if prompt is None:
        # Render the prefix once; every default prompt below starts with it.
        prefix = (prefix or SQL_PREFIX).format(dialect=toolkit.dialect, top_k=top_k)
    else:
        if "top_k" in prompt.input_variables:
            prompt = prompt.partial(top_k=str(top_k))
        if "dialect" in prompt.input_variables:
            prompt = prompt.partial(dialect=toolkit.dialect)
        if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
            db_context = toolkit.get_context()
            # Once the schema / table list is baked into the prompt, the tool
            # that would fetch the same information is dropped.
            if "table_info" in prompt.input_variables:
                prompt = prompt.partial(table_info=db_context["table_info"])
                tools = [
                    tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
                ]
            if "table_names" in prompt.input_variables:
                prompt = prompt.partial(table_names=db_context["table_names"])
                tools = [
                    tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
                ]

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        if prompt is None:
            from langchain.agents.mrkl import prompt as react_prompt

            # The prefix is already rendered, but the joined template is parsed
            # again by PromptTemplate; escape braces so they stay literal text
            # (matching the chat branches, where the prefix is used verbatim).
            literal_prefix = cast(str, prefix).replace("{", "{{").replace("}", "}}")
            template = "\n\n".join(
                [
                    literal_prefix,
                    "{tools}",
                    format_instructions or react_prompt.FORMAT_INSTRUCTIONS,
                    suffix or react_prompt.SUFFIX,
                ]
            )
            prompt = PromptTemplate.from_template(template)
        agent = RunnableAgent(
            runnable=create_react_agent(llm, tools, prompt),
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        if prompt is None:
            prompt = _default_sql_chat_prompt(cast(str, prefix), suffix)
        agent = RunnableAgent(
            runnable=create_openai_functions_agent(llm, tools, prompt),  # type: ignore
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )
    elif agent_type in ("openai-tools", "tool-calling"):
        if prompt is None:
            prompt = _default_sql_chat_prompt(cast(str, prefix), suffix)
        if agent_type == "openai-tools":
            runnable = create_openai_tools_agent(llm, tools, prompt)  # type: ignore
        else:
            runnable = create_tool_calling_agent(llm, tools, prompt)  # type: ignore
        agent = RunnableMultiActionAgent(  # type: ignore[assignment]
            runnable=runnable,
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )
    else:
        raise ValueError(
            f"Agent type {agent_type} not supported at the moment. Must be one of "
            "'tool-calling', 'openai-tools', 'openai-functions', or "
            "'zero-shot-react-description'."
        )

    return AgentExecutor(
        name="SQL Agent Executor",
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )