# Source code for langchain_experimental.sql.base

"""Chain for interacting with a SQL database."""
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.schema import BasePromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.prompt import PromptTemplate

from langchain_experimental.pydantic_v1 import Extra, Field, root_validator

# Output-dict key under which the chains in this module expose the list of
# recorded intermediate steps (only present when return_intermediate_steps is set).
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
# Marker appended after the user question to cue SQL generation; if the LLM
# echoes it back, everything up to and including it is stripped before execution.
SQL_QUERY = "SQLQuery:"
# Marker after which LLM output is discarded before the query is executed.
SQL_RESULT = "SQLResult:"


class SQLDatabaseChain(Chain):
    """Chain for interacting with a SQL database.

    Example:
        .. code-block:: python

            from langchain_experimental.sql import SQLDatabaseChain
            from langchain_community.llms import OpenAI
            from langchain_community.utilities import SQLDatabase

            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain needs.
        Failure to do so may result in data corruption or loss, since this chain may
        attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this chain.
        This issue shows an example negative outcome if these steps are not taken:
        https://github.com/langchain-ai/langchain/issues/5923
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    database: SQLDatabase = Field(exclude=True)
    """SQL database to connect to."""
    prompt: Optional[BasePromptTemplate] = None
    """[Deprecated] Prompt used to translate natural language into SQL."""
    top_k: int = 5
    """Number of results to return from the query."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_sql: bool = False
    """Return the SQL command directly, without executing it."""
    return_intermediate_steps: bool = False
    """Whether to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether to return the raw result of querying the SQL table directly."""
    use_query_checker: bool = False
    """Whether the query checker should be used to try to fix the initial SQL
    produced by the LLM."""
    query_checker_prompt: Optional[BasePromptTemplate] = None
    """Prompt template that the query checker should use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and derive ``llm_chain`` from it.

        Args:
            values: Raw constructor values, before field validation.

        Returns:
            The same mapping, with ``llm_chain`` filled in when only a non-None
            ``llm`` was supplied.
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
                "Please instantiate with llm_chain argument or using the from_llm "
                "class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                database = values["database"]
                # Prefer an explicit prompt, then a dialect-specific one, then
                # the generic default.
                prompt = values.get("prompt") or SQL_PROMPTS.get(
                    database.dialect, PROMPT
                )
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output key, plus the intermediate-steps key when enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL for the question, run it, and produce an answer.

        Args:
            inputs: Chain inputs. Must contain ``self.input_key``; may contain
                ``"table_names_to_use"`` to restrict the tables described to
                the LLM.
            run_manager: Callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            A dict with ``self.output_key`` set to the generated SQL (when
            ``return_sql``), the raw query result (when ``return_direct``) or
            the LLM's final answer. When ``return_intermediate_steps`` is set,
            the dict also carries ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            Exception: Any error raised while generating or executing SQL is
                re-raised with an ``intermediate_steps`` attribute attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\n{SQL_QUERY}"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        intermediate_steps: List = []
        try:
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            if self.return_sql:
                # output_keys also declares INTERMEDIATE_STEPS_KEY when
                # return_intermediate_steps is set, so the early return must
                # supply it as well or the declared outputs are incomplete.
                sql_only_result: Dict[str, Any] = {self.output_key: sql_cmd}
                if self.return_intermediate_steps:
                    intermediate_steps.append(sql_cmd)  # output: sql generation
                    sql_only_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
                return sql_only_result
            if not self.use_query_checker:
                _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    sql_cmd
                )  # output: sql generation (no checker)
                intermediate_steps.append({"sql_cmd": sql_cmd})  # input: sql exec
                # Keep only the statement itself if the LLM echoed the
                # "SQLQuery:" / "SQLResult:" scaffolding around it.
                if SQL_QUERY in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_QUERY)[1].strip()
                if SQL_RESULT in sql_cmd:
                    sql_cmd = sql_cmd.split(SQL_RESULT)[0].strip()
                result = self.database.run(sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm, prompt=query_checker_prompt
                )
                query_checker_inputs = {
                    "query": sql_cmd,
                    "dialect": self.database.dialect,
                }
                checked_sql_command: str = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                ).strip()
                intermediate_steps.append(
                    checked_sql_command
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_sql_command, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_sql_command}
                )  # input: sql exec
                result = self.database.run(checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                sql_cmd = checked_sql_command

            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result, otherwise try to get a human
            # readable final answer
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        **kwargs: Any,
    ) -> SQLDatabaseChain:
        """Create a SQLDatabaseChain from an LLM and a database connection.

        Args:
            llm: Language model used to build the internal ``LLMChain``.
            db: Database the chain will query.
            prompt: Prompt for SQL generation. When omitted, the prompt for
                ``db.dialect`` is looked up in ``SQL_PROMPTS``, falling back to
                ``PROMPT``.
            **kwargs: Extra field values forwarded to the constructor.

        Returns:
            The configured chain.

        *Security note*: Make sure that the database connection uses credentials
            that are narrowly-scoped to only include the permissions this chain needs.
            Failure to do so may result in data corruption or loss, since this chain may
            attempt commands like `DROP TABLE` or `INSERT` if appropriately prompted.
            The best way to guard against such negative outcomes is to (as appropriate)
            limit the permissions granted to the credentials used with this chain.
            This issue shows an example negative outcome if these steps are not taken:
            https://github.com/langchain-ai/langchain/issues/5923
        """
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, database=db, **kwargs)
class SQLDatabaseSequentialChain(Chain):
    """Sequential chain for querying a SQL database.

    The chain is as follows:
    1. Based on the query, determine which tables to use.
    2. Based on those tables, call the normal SQL database chain.

    This is useful in cases where the number of tables in the database is large.
    """

    decider_chain: LLMChain
    sql_chain: SQLDatabaseChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        query_prompt: BasePromptTemplate = PROMPT,
        decider_prompt: BasePromptTemplate = DECIDER_PROMPT,
        **kwargs: Any,
    ) -> SQLDatabaseSequentialChain:
        """Load the necessary chains.

        Args:
            llm: Language model shared by the decider chain and the SQL chain.
            db: Database the SQL chain will query.
            query_prompt: Prompt used by the SQL chain to generate queries.
            decider_prompt: Prompt used to choose which tables are relevant.
            **kwargs: Extra field values; forwarded both to
                ``SQLDatabaseChain.from_llm`` and to this class's constructor.

        Returns:
            The configured sequential chain.
        """
        sql_chain = SQLDatabaseChain.from_llm(llm, db, prompt=query_prompt, **kwargs)
        decider_chain = LLMChain(
            llm=llm, prompt=decider_prompt, output_key="table_names"
        )
        return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output key, plus the intermediate-steps key when enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick the relevant tables, then delegate the question to the SQL chain.

        Args:
            inputs: Chain inputs; must contain ``self.input_key``.
            run_manager: Callback manager for this run; a no-op manager is used
                when omitted.

        Returns:
            The outputs of ``self.sql_chain`` for the question, restricted to
            the tables selected by the decider chain.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _table_names = self.sql_chain.database.get_usable_table_names()
        table_names = ", ".join(_table_names)
        llm_inputs = {
            "query": inputs[self.input_key],
            "table_names": table_names,
        }
        # Set for O(1) case-insensitive membership tests below.
        _lowercased_table_names = {name.lower() for name in _table_names}
        # Pass child callbacks so the decider LLM call is attached to this run,
        # matching how SQLDatabaseChain invokes its own LLM chain.
        table_names_from_chain = self.decider_chain.predict_and_parse(
            callbacks=_run_manager.get_child(), **llm_inputs
        )
        # Drop any names the LLM suggested that are not usable tables.
        table_names_to_use = [
            name
            for name in table_names_from_chain
            if name.lower() in _lowercased_table_names
        ]
        _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(table_names_to_use), color="yellow", verbose=self.verbose
        )
        new_inputs = {
            self.sql_chain.input_key: inputs[self.input_key],
            "table_names_to_use": table_names_to_use,
        }
        return self.sql_chain(
            new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "sql_database_sequential_chain"