Source code for langchain_community.tools.sql_database.tool

# flake8: noqa
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Sequence, Type, Union

from sqlalchemy.engine import Result

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

from langchain_core.language_models import BaseLanguageModel
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.tools import BaseTool
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER


class BaseSQLDatabaseTool(BaseModel):
    """Base tool for interacting with a SQL database.

    Declares the ``db`` field that the concrete SQL tools in this module
    read from.
    """

    # Database wrapper the tools operate on. Declared with ``exclude=True``
    # (pydantic field option), so it is left out of the model's exported form.
    db: SQLDatabase = Field(exclude=True)

    class Config(BaseTool.Config):
        # Reuse the tool configuration as-is; nothing is overridden here.
        pass
class _QuerySQLDataBaseToolInput(BaseModel):
    # Argument schema used by QuerySQLDataBaseTool (``args_schema``).
    # A single required string field: the SQL statement to execute.
    query: str = Field(
        default=...,
        description="A detailed and correct SQL query.",
    )
class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for querying a SQL database."""

    name: str = "sql_db_query"
    # Runtime text reproduced verbatim (including the doubled period and the
    # leading/trailing space); only the way the literal is spelled differs.
    description: str = (
        " Execute a SQL query against the database and get back the result.."
        " If the query is not correct, an error message will be returned."
        " If an error is returned, rewrite the query, check the query, and try again. "
    )
    args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result]:
        """Execute the query and return the result or an error message.

        Args:
            query: SQL text handed to ``self.db.run_no_throw``.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            Whatever ``self.db.run_no_throw`` returns for ``query``.
        """
        outcome = self.db.run_no_throw(query)
        return outcome
class _InfoSQLDatabaseToolInput(BaseModel):
    # Argument schema used by InfoSQLDatabaseTool (``args_schema``).
    # One required string field holding comma-separated table names.
    table_names: str = Field(
        default=...,
        description=(
            "A comma-separated list of the table names "
            "for which to return the schema. "
            "Example input: 'table1, table2, table3'"
        ),
    )
class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting metadata about a SQL database."""

    name: str = "sql_db_schema"
    description: str = "Get the schema and sample rows for the specified SQL tables."
    args_schema: Type[BaseModel] = _InfoSQLDatabaseToolInput

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for the tables named in a comma-separated list.

        Args:
            table_names: Comma-separated table names; surrounding whitespace
                around each name is stripped.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            The value returned by ``self.db.get_table_info_no_throw``.
        """
        # Split on commas first, then trim each piece individually.
        requested = [piece.strip() for piece in table_names.split(",")]
        return self.db.get_table_info_no_throw(requested)
class _ListSQLDataBaseToolInput(BaseModel):
    # Argument schema used by ListSQLDatabaseTool (``args_schema``).
    # The single field defaults to the empty string.
    tool_input: str = Field(
        default="",
        description="An empty string",
    )
class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
    """Tool for getting table names."""

    name: str = "sql_db_list_tables"
    description: str = (
        "Input is an empty string, "
        "output is a comma-separated list of tables in the database."
    )
    args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get a comma-separated list of table names.

        Args:
            tool_input: Accepted but not read by this method.
            run_manager: Accepted for the tool interface; not used here.

        Returns:
            The names from ``self.db.get_usable_table_names()`` joined
            with ``", "``.
        """
        usable_tables = self.db.get_usable_table_names()
        return ", ".join(usable_tables)
class _QuerySQLCheckerToolInput(BaseModel):
    # Argument schema used by QuerySQLCheckerTool (``args_schema``).
    # One required string field: the SQL statement to be reviewed.
    query: str = Field(
        default=...,
        description="A detailed and SQL query to be checked.",
    )
class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    # NOTE(review): this field is declared but ``initialize_llm_chain`` builds
    # its default prompt from the module-level QUERY_CHECKER constant, not from
    # a ``template`` value passed by the caller.
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    # Filled in by ``initialize_llm_chain`` when the caller does not supply one.
    llm_chain: Any = Field(init=False)
    name: str = "sql_db_query_checker"
    description: str = (
        " Use this tool to double check if your query is correct before executing it."
        " Always use this tool before executing a query with sql_db_query! "
    )
    args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure ``values`` carries an ``llm_chain`` with the expected inputs.

        Args:
            values: Raw constructor keyword arguments (this validator runs
                with ``pre=True``).

        Returns:
            The same mapping, with ``llm_chain`` added when it was absent.

        Raises:
            ValueError: If the chain's prompt does not take exactly the
                ``query`` and ``dialect`` input variables.
        """
        if "llm_chain" not in values:
            # Imported lazily, inside the validator, as in the original code.
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),  # type: ignore[arg-type]
                prompt=PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["dialect", "query"]
                ),
            )

        # Compare order-insensitively: ``_run``/``_arun`` pass both values by
        # keyword, and the error message itself lists them as
        # ['query', 'dialect'], so a caller-built prompt declared in that
        # order must be accepted.
        declared_inputs = sorted(values["llm_chain"].prompt.input_variables)
        if declared_inputs != ["dialect", "query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query.

        Args:
            query: SQL text to be checked.
            run_manager: Optional callback manager; when given, its child
                manager is forwarded to the chain as ``callbacks``.

        Returns:
            The result of ``self.llm_chain.predict``.
        """
        return self.llm_chain.predict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Asynchronous counterpart of ``_run``.

        Args:
            query: SQL text to be checked.
            run_manager: Optional async callback manager; when given, its
                child manager is forwarded to the chain as ``callbacks``.

        Returns:
            The awaited result of ``self.llm_chain.apredict``.
        """
        return await self.llm_chain.apredict(
            query=query,
            dialect=self.db.dialect,
            callbacks=run_manager.get_child() if run_manager else None,
        )