Source code for langchain_community.agent_toolkits.sql.base

"""SQL代理。"""
from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
    cast,
)

from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
)

from langchain_community.agent_toolkits.sql.prompt import (
    SQL_FUNCTIONS_SUFFIX,
    SQL_PREFIX,
)
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.tools.sql_database.tool import (
    InfoSQLDatabaseTool,
    ListSQLDatabaseTool,
)

if TYPE_CHECKING:
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.agent_types import AgentType
    from langchain_core.callbacks import BaseCallbackManager
    from langchain_core.language_models import BaseLanguageModel
    from langchain_core.tools import BaseTool

    from langchain_community.utilities.sql_database import SQLDatabase


def _default_chat_prompt(prefix: str, suffix: Optional[str]) -> ChatPromptTemplate:
    """Build the default chat prompt shared by the function/tool-calling agents."""
    messages: List = [
        SystemMessage(content=prefix),
        HumanMessagePromptTemplate.from_template("{input}"),
        AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    return ChatPromptTemplate.from_messages(messages)


def create_sql_agent(
    llm: BaseLanguageModel,
    toolkit: Optional[SQLDatabaseToolkit] = None,
    agent_type: Optional[
        Union[AgentType, Literal["openai-tools", "tool-calling"]]
    ] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    format_instructions: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    extra_tools: Sequence[BaseTool] = (),
    *,
    db: Optional[SQLDatabase] = None,
    prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a SQL agent from an LLM and a toolkit or database.

    Args:
        llm: Language model to use for the agent. If agent_type is "tool-calling"
            then llm is expected to support tool calling.
        toolkit: SQLDatabaseToolkit for the agent to use. Must provide exactly one
            of 'toolkit' or 'db'. Specify 'toolkit' if you want to use a different
            model for the agent and the toolkit.
        agent_type: One of "tool-calling", "openai-tools", "openai-functions", or
            "zero-shot-react-description". Defaults to
            "zero-shot-react-description". "tool-calling" is recommended over the
            legacy "openai-tools" and "openai-functions" types.
        callback_manager: DEPRECATED. Pass the "callbacks" key into
            'agent_executor_kwargs' instead to pass constructor callbacks to
            AgentExecutor.
        prefix: Prompt prefix string. Must contain the variables "top_k" and
            "dialect", which are filled in here. Defaults to SQL_PREFIX. Used for
            every agent type when 'prompt' is not given.
        suffix: Prompt suffix string. Default depends on agent type: the generic
            ReAct suffix for "zero-shot-react-description", SQL_FUNCTIONS_SUFFIX
            otherwise. For "zero-shot-react-description" the suffix is inserted as
            template text, so it should keep the placeholders the default ReAct
            suffix provides (presumably {input} and {agent_scratchpad}).
        format_instructions: Formatting instructions placed between the tool
            descriptions and the suffix when 'agent_type' is
            "zero-shot-react-description". Ignored otherwise.
        input_variables: DEPRECATED. Ignored.
        top_k: Number of rows to query for by default.
        max_iterations: Passed to AgentExecutor init.
        max_execution_time: Passed to AgentExecutor init.
        early_stopping_method: Passed to AgentExecutor init.
        verbose: AgentExecutor verbosity.
        agent_executor_kwargs: Arbitrary additional AgentExecutor args.
        extra_tools: Additional tools to give to the agent on top of the ones that
            come with SQLDatabaseToolkit.
        db: SQLDatabase from which to create a SQLDatabaseToolkit. The toolkit is
            created using 'db' and 'llm'. Must provide exactly one of 'db' or
            'toolkit'.
        prompt: Complete agent prompt. prompt and {prefix, suffix,
            format_instructions, input_variables} are mutually exclusive. If the
            prompt has "top_k" / "dialect" variables they are filled in. If it has
            "table_info" / "table_names" variables they are filled from the
            toolkit context, and the now-redundant info / list tool is removed.
        **kwargs: Arbitrary additional Agent args.

    Returns:
        An AgentExecutor with the specified agent_type agent.

    Raises:
        ValueError: If neither or both of 'toolkit' and 'db' are given, or if
            'agent_type' is not one of the supported types.

    Example:

        .. code-block:: python

            from langchain_openai import ChatOpenAI
            from langchain_community.agent_toolkits import create_sql_agent
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase.from_uri("sqlite:///Chinook.db")
            llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
            agent_executor = create_sql_agent(llm, db=db, agent_type="tool-calling", verbose=True)

    """  # noqa: E501
    from langchain.agents import (
        create_openai_functions_agent,
        create_openai_tools_agent,
        create_react_agent,
        create_tool_calling_agent,
    )
    from langchain.agents.agent import (
        AgentExecutor,
        RunnableAgent,
        RunnableMultiActionAgent,
    )
    from langchain.agents.agent_types import AgentType

    if toolkit is None and db is None:
        raise ValueError(
            "Must provide exactly one of 'toolkit' or 'db'. Received neither."
        )
    if toolkit is not None and db is not None:
        raise ValueError(
            "Must provide exactly one of 'toolkit' or 'db'. Received both."
        )

    if toolkit is None:
        toolkit = SQLDatabaseToolkit(llm=llm, db=db)  # type: ignore[arg-type]
    agent_type = agent_type or AgentType.ZERO_SHOT_REACT_DESCRIPTION
    tools = toolkit.get_tools() + list(extra_tools)

    if prompt is None:
        prefix = prefix or SQL_PREFIX
        prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k)
    else:
        if "top_k" in prompt.input_variables:
            prompt = prompt.partial(top_k=str(top_k))
        if "dialect" in prompt.input_variables:
            prompt = prompt.partial(dialect=toolkit.dialect)
        if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
            db_context = toolkit.get_context()
            if "table_info" in prompt.input_variables:
                prompt = prompt.partial(table_info=db_context["table_info"])
                # Table info is already baked into the prompt, so the info tool
                # would be redundant.
                tools = [
                    tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
                ]
            if "table_names" in prompt.input_variables:
                prompt = prompt.partial(table_names=db_context["table_names"])
                # Likewise, table names are in the prompt; drop the list tool.
                tools = [
                    tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
                ]

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        if prompt is None:
            from langchain.agents.mrkl import prompt as react_prompt

            format_instructions = (
                format_instructions or react_prompt.FORMAT_INSTRUCTIONS
            )
            # Use the SQL-specific prefix (already formatted with dialect/top_k)
            # instead of the generic ReAct prefix. Escape its braces so that
            # PromptTemplate treats any leftover braces as literal text rather
            # than as template variables.
            escaped_prefix = (
                cast(str, prefix).replace("{", "{{").replace("}", "}}")
            )
            template = "\n\n".join(
                [
                    escaped_prefix,
                    "{tools}",
                    format_instructions,
                    suffix or react_prompt.SUFFIX,
                ]
            )
            prompt = PromptTemplate.from_template(template)
        agent = RunnableAgent(
            runnable=create_react_agent(llm, tools, prompt),
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )

    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        if prompt is None:
            prompt = _default_chat_prompt(cast(str, prefix), suffix)
        agent = RunnableAgent(
            runnable=create_openai_functions_agent(llm, tools, prompt),  # type: ignore
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )

    elif agent_type in ("openai-tools", "tool-calling"):
        if prompt is None:
            prompt = _default_chat_prompt(cast(str, prefix), suffix)
        if agent_type == "openai-tools":
            runnable = create_openai_tools_agent(llm, tools, prompt)  # type: ignore
        else:
            runnable = create_tool_calling_agent(llm, tools, prompt)  # type: ignore
        agent = RunnableMultiActionAgent(  # type: ignore[assignment]
            runnable=runnable,
            input_keys_arg=["input"],
            return_keys_arg=["output"],
            **kwargs,
        )

    else:
        raise ValueError(
            f"Agent type {agent_type} not supported at the moment. Must be one of "
            "'tool-calling', 'openai-tools', 'openai-functions', or "
            "'zero-shot-react-description'."
        )

    return AgentExecutor(
        name="SQL Agent Executor",
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )