Source code for langchain.chains.sql_database.query
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
if TYPE_CHECKING:
from langchain_community.utilities.sql_database import SQLDatabase
def _strip(text: str) -> str:
return text.strip()
[docs]def create_sql_query_chain(
llm: BaseLanguageModel,
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]:
"""创建一个生成SQL查询的链。
*安全提示*: 此链为给定数据库生成SQL查询。
SQLDatabase类提供了一个get_table_info方法,可用于获取表的列信息以及示例数据。
为了减少泄露敏感数据的风险,限制权限为只读,并范围限定在所需的表中。
可选地,使用SQLInputWithTables输入类型指定允许访问的表。
控制谁可以向此链提交请求的访问权限。
有关更多信息,请参见 https://python.langchain.com/docs/security。
参数:
llm: 要使用的语言模型。
db: 用于生成查询的SQLDatabase。
prompt: 要使用的提示。如果未提供任何提示,将根据方言选择一个。默认为None。有关更多信息,请参见下面的Prompt部分。
k: 每个select语句返回的结果数量。默认为5。
返回:
一个接受问题并生成SQL查询以回答该问题的链。
示例:
.. code-block:: python
# pip install -U langchain langchain-community langchain-openai
from langchain_openai import ChatOpenAI
from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
提示:
如果未提供提示,则根据SQLDatabase方言选择默认提示。如果提供了提示,则必须支持输入变量:
* input: 用户问题加后缀"SQLQuery: "传递到这里。
* top_k: 每个select语句返回的结果数量(传递给此函数的`k`参数)在此处传递。
* table_info: 表定义和示例行在此处传递。如果用户在调用链时指定了"table_names_to_use",则只包括这些表。否则,将包括所有表。
* dialect(可选): 如果提示中包含方言输入变量,则db方言将传递到这里。
这是一个示例提示:
.. code-block:: python
from langchain_core.prompts import PromptTemplate
template = '''给定一个输入问题,首先创建一个语法正确的{dialect}查询来运行,然后查看查询的结果并返回答案。
使用以下格式:
问题: "这里是问题"
SQL查询: "要运行的SQL查询"
SQL结果: "SQL查询的结果"
答案: "这里是最终答案"
仅使用以下表:
{table_info}.
问题: {input}'''
prompt = PromptTemplate.from_template(template)
""" # noqa: E501
if prompt is not None:
prompt_to_use = prompt
elif db.dialect in SQL_PROMPTS:
prompt_to_use = SQL_PROMPTS[db.dialect]
else:
prompt_to_use = PROMPT
if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables):
raise ValueError(
f"Prompt must have input variables: 'input', 'top_k', "
f"'table_info'. Received prompt with input variables: "
f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}"
)
if "dialect" in prompt_to_use.input_variables:
prompt_to_use = prompt_to_use.partial(dialect=db.dialect)
inputs = {
"input": lambda x: x["question"] + "\nSQLQuery: ",
"table_info": lambda x: db.get_table_info(
table_names=x.get("table_names_to_use")
),
}
return (
RunnablePassthrough.assign(**inputs) # type: ignore
| (
lambda x: {
k: v
for k, v in x.items()
if k not in ("question", "table_names_to_use")
}
)
| prompt_to_use.partial(top_k=str(k))
| llm.bind(stop=["\nSQLResult:"])
| StrOutputParser()
| _strip
)