Source code for langchain.chains.sql_database.query

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough

from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS

if TYPE_CHECKING:
    from langchain_community.utilities.sql_database import SQLDatabase


def _strip(text: str) -> str:
    return text.strip()


[docs]class SQLInput(TypedDict): """SQL链的输入。""" question: str
[docs]class SQLInputWithTables(TypedDict): """SQL链的输入。""" question: str table_names_to_use: List[str]
[docs]def create_sql_query_chain( llm: BaseLanguageModel, db: SQLDatabase, prompt: Optional[BasePromptTemplate] = None, k: int = 5, ) -> Runnable[Union[SQLInput, SQLInputWithTables, Dict[str, Any]], str]: """创建一个生成SQL查询的链。 *安全提示*: 此链为给定数据库生成SQL查询。 SQLDatabase类提供了一个get_table_info方法,可用于获取表的列信息以及示例数据。 为了减少泄露敏感数据的风险,限制权限为只读,并范围限定在所需的表中。 可选地,使用SQLInputWithTables输入类型指定允许访问的表。 控制谁可以向此链提交请求的访问权限。 有关更多信息,请参见 https://python.langchain.com/docs/security。 参数: llm: 要使用的语言模型。 db: 用于生成查询的SQLDatabase。 prompt: 要使用的提示。如果未提供任何提示,将根据方言选择一个。默认为None。有关更多信息,请参见下面的Prompt部分。 k: 每个select语句返回的结果数量。默认为5。 返回: 一个接受问题并生成SQL查询以回答该问题的链。 示例: .. code-block:: python # pip install -U langchain langchain-community langchain-openai from langchain_openai import ChatOpenAI from langchain.chains import create_sql_query_chain from langchain_community.utilities import SQLDatabase db = SQLDatabase.from_uri("sqlite:///Chinook.db") llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0) chain = create_sql_query_chain(llm, db) response = chain.invoke({"question": "How many employees are there"}) 提示: 如果未提供提示,则根据SQLDatabase方言选择默认提示。如果提供了提示,则必须支持输入变量: * input: 用户问题加后缀"SQLQuery: "传递到这里。 * top_k: 每个select语句返回的结果数量(传递给此函数的`k`参数)在此处传递。 * table_info: 表定义和示例行在此处传递。如果用户在调用链时指定了"table_names_to_use",则只包括这些表。否则,将包括所有表。 * dialect(可选): 如果提示中包含方言输入变量,则db方言将传递到这里。 这是一个示例提示: .. code-block:: python from langchain_core.prompts import PromptTemplate template = '''给定一个输入问题,首先创建一个语法正确的{dialect}查询来运行,然后查看查询的结果并返回答案。 使用以下格式: 问题: "这里是问题" SQL查询: "要运行的SQL查询" SQL结果: "SQL查询的结果" 答案: "这里是最终答案" 仅使用以下表: {table_info}. 问题: {input}''' prompt = PromptTemplate.from_template(template) """ # noqa: E501 if prompt is not None: prompt_to_use = prompt elif db.dialect in SQL_PROMPTS: prompt_to_use = SQL_PROMPTS[db.dialect] else: prompt_to_use = PROMPT if {"input", "top_k", "table_info"}.difference(prompt_to_use.input_variables): raise ValueError( f"Prompt must have input variables: 'input', 'top_k', " f"'table_info'. Received prompt with input variables: " f"{prompt_to_use.input_variables}. Full prompt:\n\n{prompt_to_use}" ) if "dialect" in prompt_to_use.input_variables: prompt_to_use = prompt_to_use.partial(dialect=db.dialect) inputs = { "input": lambda x: x["question"] + "\nSQLQuery: ", "table_info": lambda x: db.get_table_info( table_names=x.get("table_names_to_use") ), } return ( RunnablePassthrough.assign(**inputs) # type: ignore | ( lambda x: { k: v for k, v in x.items() if k not in ("question", "table_names_to_use") } ) | prompt_to_use.partial(top_k=str(k)) | llm.bind(stop=["\nSQLResult:"]) | StrOutputParser() | _strip )