# Source code for langchain_community.tools.spark_sql.tool

# flake8: noqa
"""Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

from langchain_core.language_models import BaseLanguageModel
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.prompts import PromptTemplate
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_core.tools import BaseTool
from langchain_community.tools.spark_sql.prompt import QUERY_CHECKER


class BaseSparkSQLTool(BaseModel):
    """Common base for tools that talk to Spark SQL.

    Holds the ``SparkSQL`` wrapper that concrete tools reach through
    ``self.db``.
    """

    # ``exclude=True`` keeps the database handle out of pydantic's
    # ``.dict()`` / ``.json()`` export of the tool.
    db: SparkSQL = Field(exclude=True)

    class Config(BaseTool.Config):
        """Reuse ``BaseTool``'s pydantic configuration without changes."""
class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool that executes a SQL query against Spark SQL."""

    name: str = "query_sql_db"
    description: str = """ Input to this tool is a detailed and correct SQL query, output is a result from the Spark SQL. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute ``query`` and hand back its result as text.

        Delegates to ``SparkSQL.run_no_throw``. Per this tool's description,
        an incorrect query yields an error message in the returned text
        rather than an exception, so the agent can read it and retry.

        Args:
            query: The SQL statement to execute.
            run_manager: Optional callback manager (not used here).
        """
        outcome = self.db.run_no_throw(query)
        return outcome
class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool that fetches metadata (schema details) for Spark SQL tables."""

    name: str = "schema_sql_db"
    description: str = """ Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling list_tables_sql_db first! Example Input: "table1, table2, table3" """

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return schema information for a comma-separated list of tables.

        Args:
            table_names: Table names separated by commas, e.g.
                ``"table1, table2"``. Whitespace around each name is
                ignored, so ``"table1,table2"`` is accepted too.
            run_manager: Optional callback manager (not used here).

        Returns:
            The text produced by ``SparkSQL.get_table_info_no_throw``.
            Per the tool description, this is the schema and sample rows.
            The ``_no_throw`` variant presumably reports failures (e.g.
            unknown tables) as text rather than raising.
        """
        # Split on bare commas and strip each piece. Splitting on the exact
        # string ", " would turn "a,b" into one bogus table name. It would
        # also leave stray whitespace (such as a trailing newline) attached
        # to the names.
        names = [name.strip() for name in table_names.split(",")]
        return self.db.get_table_info_no_throw(names)
class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool that lists the table names available in Spark SQL."""

    name: str = "list_tables_sql_db"
    description: str = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Return the usable table names as a single ``", "``-joined string.

        Args:
            tool_input: Ignored; the tool expects an empty string.
            run_manager: Optional callback manager (not used here).

        Returns:
            The names from ``SparkSQL.get_usable_table_names()``, joined
            with ``", "``.
        """
        return ", ".join(self.db.get_usable_table_names())
class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
    """Use an LLM to check whether a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    # Prompt template for the checker chain; must take ``{query}`` as its
    # only input variable (enforced in ``initialize_llm_chain``).
    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: Any = Field(init=False)
    name: str = "query_checker_sql_db"
    description: str = """ Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with query_sql_db! """

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``llm_chain`` from ``llm`` and ``template`` when not supplied.

        Args:
            values: Raw constructor keyword arguments (``pre=True`` means
                this runs before field validation).

        Returns:
            ``values`` with ``llm_chain`` populated.

        Raises:
            ValueError: If the chain's prompt does not use exactly
                ``["query"]`` as its input variables.
        """
        if "llm_chain" not in values:
            # Imported here rather than at module top. Presumably this avoids
            # requiring the ``langchain`` package just to import this module.
            from langchain.chains.llm import LLMChain

            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),  # type: ignore[arg-type]
                prompt=PromptTemplate(
                    # Respect a caller-supplied ``template``; fall back to the
                    # field's default when none was given.
                    template=values.get("template", QUERY_CHECKER),
                    input_variables=["query"],
                ),
            )
        if values["llm_chain"].prompt.input_variables != ["query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool need to use ['query'] as input_variables "
                "for the embedded prompt"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Ask the LLM to review ``query`` and return its response.

        Child callbacks are forwarded to the chain when a run manager is given.
        """
        return self.llm_chain.predict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async counterpart of ``_run`` using ``llm_chain.apredict``."""
        return await self.llm_chain.apredict(
            query=query, callbacks=run_manager.get_child() if run_manager else None
        )