langchain_community.agent_toolkits.sql.toolkit ηζΊδ»£η
"""Toolkit for interacting with an SQL database."""
from typing import List
from langchain_core.caches import BaseCache as BaseCache
from langchain_core.callbacks import Callbacks as Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.tools import BaseTool
from langchain_core.tools.base import BaseToolkit
from pydantic import ConfigDict, Field
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QuerySQLCheckerTool,
QuerySQLDatabaseTool,
)
from langchain_community.tools.sql_database.tool import (
QuerySQLDataBaseTool as QuerySQLDataBaseTool, # keep import for backwards compat.
)
from langchain_community.utilities.sql_database import SQLDatabase
[docs]
class SQLDatabaseToolkit(BaseToolkit):
"""SQLDatabaseToolkit for interacting with SQL databases.
Setup:
Install ``langchain-community``.
.. code-block:: bash
pip install -U langchain-community
Key init args:
db: SQLDatabase
The SQL database.
llm: BaseLanguageModel
The language model (for use with QuerySQLCheckerTool)
Instantiate:
.. code-block:: python
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_openai import ChatOpenAI
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
Tools:
.. code-block:: python
toolkit.get_tools()
Use within an agent:
.. code-block:: python
from langchain import hub
from langgraph.prebuilt import create_react_agent
# Pull prompt (or define your own)
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
system_message = prompt_template.format(dialect="SQLite", top_k=5)
# Create agent
agent_executor = create_react_agent(
llm, toolkit.get_tools(), state_modifier=system_message
)
# Query agent
example_query = "Which country's customers spent the most?"
events = agent_executor.stream(
{"messages": [("user", example_query)]},
stream_mode="values",
)
for event in events:
event["messages"][-1].pretty_print()
""" # noqa: E501
db: SQLDatabase = Field(exclude=True)
llm: BaseLanguageModel = Field(exclude=True)
@property
def dialect(self) -> str:
"""Return string representation of SQL dialect to use."""
return self.db.dialect
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
[docs]
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
list_sql_database_tool = ListSQLDatabaseTool(db=self.db)
info_sql_database_tool_description = (
"Input to this tool is a comma-separated list of tables, output is the "
"schema and sample rows for those tables. "
"Be sure that the tables actually exist by calling "
f"{list_sql_database_tool.name} first! "
"Example Input: table1, table2, table3"
)
info_sql_database_tool = InfoSQLDatabaseTool(
db=self.db, description=info_sql_database_tool_description
)
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {info_sql_database_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDatabaseTool(
db=self.db, description=query_sql_database_tool_description
)
query_sql_checker_tool_description = (
"Use this tool to double check if your query is correct before executing "
"it. Always use this tool before executing a query with "
f"{query_sql_database_tool.name}!"
)
query_sql_checker_tool = QuerySQLCheckerTool(
db=self.db, llm=self.llm, description=query_sql_checker_tool_description
)
return [
query_sql_database_tool,
info_sql_database_tool,
list_sql_database_tool,
query_sql_checker_tool,
]
[docs]
def get_context(self) -> dict:
"""Return db context that you may want in agent prompt."""
return self.db.get_context()
SQLDatabaseToolkit.model_rebuild()