Text-to-SQL指南(查询引擎+检索器)¶
这是关于LlamaIndex的Text-to-SQL功能的基本指南。
- 首先,我们将展示如何在一个玩具数据集上执行文本到SQL转换:这将进行“检索”(对数据库进行SQL查询)和“合成”。
- 然后,我们将展示如何构建一个TableIndex来动态检索相关表格以供查询时使用。
- 最后,我们将向您展示如何单独定义一个文本到SQL检索器。
注意: 任何Text-to-SQL应用程序都应意识到执行任意SQL查询可能存在安全风险。建议采取必要的预防措施,例如使用受限角色、只读数据库、沙箱等。
如果您在colab上打开这个笔记本,您可能需要安装LlamaIndex 🦙。
%pip install llama-index-llms-openai
!pip install llama-index
import os
import openai
os.environ["OPENAI_API_KEY"] = "sk-.."
openai.api_key = os.environ["OPENAI_API_KEY"]
# 导入日志模块
# 导入系统模块
# 配置日志基本设置,输出到标准输出流,日志级别为INFO
# 获取日志记录器并添加一个输出到标准输出流的处理器
from IPython.display import Markdown, display
创建数据库模式¶
我们使用 sqlalchemy
,一个流行的SQL数据库工具包,来创建一个空的 city_stats
表。
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
select,
)
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
# 创建城市SQL表
table_name = "city_stats"
city_stats_table = Table(
table_name,
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
定义SQL数据库¶
我们首先定义我们的SQLDatabase
抽象(它是对SQLAlchemy的轻量级封装)。
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
llm = OpenAI(temperature=0.1, model="gpt-3.5-turbo")
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
我们向我们的SQL数据库中添加了一些测试数据。
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
from sqlalchemy import insert
rows = [
{"city_name": "Toronto", "population": 2930000, "country": "Canada"},
{"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
{
"city_name": "Chicago",
"population": 2679000,
"country": "United States",
},
{"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
for row in rows:
stmt = insert(city_stats_table).values(**row)
with engine.begin() as connection:
cursor = connection.execute(stmt)
# 查看当前表格
stmt = select(
city_stats_table.c.city_name,
city_stats_table.c.population,
city_stats_table.c.country,
).select_from(city_stats_table)
with engine.connect() as connection:
results = connection.execute(stmt).fetchall()
print(results)
[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]
查询索引¶
我们首先展示如何执行原始的SQL查询,直接在表上执行。
from sqlalchemy import text
with engine.connect() as con:
rows = con.execute(text("SELECT city_name from city_stats"))
for row in rows:
print(row)
('Chicago',) ('Seoul',) ('Tokyo',) ('Toronto',)
第1部分:文本到SQL查询引擎¶
一旦我们构建了SQL数据库,我们就可以使用NLSQLTableQueryEngine来构建自然语言查询,这些查询会被合成为SQL查询。
请注意,我们需要指定要在此查询引擎中使用的表。 如果我们不指定,查询引擎将提取所有的模式上下文,这可能会超出LLM的上下文窗口。
from llama_index.core.query_engine import NLSQLTableQueryEngine
query_engine = NLSQLTableQueryEngine(
sql_database=sql_database, tables=["city_stats"], llm=llm
)
query_str = "Which city has the highest population?"
response = query_engine.query(query_str)
display(Markdown(f"<b>{response}</b>"))
The city with the highest population is Tokyo.
这个查询引擎应该在任何你可以预先指定要查询的表,或者所有表模式的总大小加上其余提示符适合你的上下文窗口的情况下使用。
第二部分:文本到SQL的查询时间检索表格¶
如果我们事先不知道要使用哪个表,并且表模式的总大小超出了上下文窗口的大小,那么我们应该将表模式存储在索引中,这样在查询时我们就可以检索到正确的模式。
我们可以使用SQLTableNodeMapping对象来实现这一点,它接受一个SQLDatabase并为传递给ObjectIndex构造函数的每个SQLTableSchema对象生成一个Node对象。
from llama_index.core.indices.struct_store.sql_query import (
SQLTableRetrieverQueryEngine,
)
from llama_index.core.objects import (
SQLTableNodeMapping,
ObjectIndex,
SQLTableSchema,
)
from llama_index.core import VectorStoreIndex
# 将日志级别设置为DEBUG以获得更详细的输出
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
(SQLTableSchema(table_name="city_stats"))
] # 为每个表添加一个SQLTableSchema
obj_index = ObjectIndex.from_objects(
table_schema_objs,
table_node_mapping,
VectorStoreIndex,
)
query_engine = SQLTableRetrieverQueryEngine(
sql_database, obj_index.as_retriever(similarity_top_k=1)
)
现在我们可以使用SQLTableRetrieverQueryEngine来查询并获取我们的响应。
response = query_engine.query("Which city has the highest population?")
display(Markdown(f"<b>{response}</b>"))
The city with the highest population is Tokyo.
# 你也可以从SQLAlchemy中获取原始结果!
response.metadata["result"]
[('Tokyo',)]
您还可以为您定义的每个表模式添加额外的上下文信息。
# 手动设置上下文文本
city_stats_text = (
"该表提供了有关给定城市的人口和国家的信息。\n用户将使用代码词进行查询,其中'foo'对应人口,'bar'对应城市。"
)
table_node_mapping = SQLTableNodeMapping(sql_database)
table_schema_objs = [
(SQLTableSchema(table_name="city_stats", context_str=city_stats_text))
]
第三部分:文本到SQL检索器¶
到目前为止,我们的文本到SQL功能被打包在一个查询引擎中,包括检索和合成两部分。
你可以单独使用SQL检索器。我们将向你展示一些不同的参数可以尝试,并且还会展示如何将其插入到我们的RetrieverQueryEngine
中,以获得大致相同的结果。
from llama_index.core.retrievers import NLSQLRetriever
# 默认检索(return_raw=True)
nl_sql_retriever = NLSQLRetriever(
sql_database, tables=["city_stats"], return_raw=True
)
results = nl_sql_retriever.retrieve(
"Return the top 5 cities (along with their populations) with the highest population."
)
from llama_index.core.response.notebook_utils import display_source_node
for n in results:
display_source_node(n)
Node ID: 458f723e-f1ac-4423-917a-522a71763390
Similarity: None
Text: [('Tokyo', 13960000), ('Seoul', 9776000), ('Toronto', 2930000), ('Chicago', 2679000)]
# 默认检索(return_raw=False)
nl_sql_retriever = NLSQLRetriever(
sql_database, tables=["city_stats"], return_raw=False
)
results = nl_sql_retriever.retrieve(
"Return the top 5 cities (along with their populations) with the highest population."
)
# 注意:所有内容都在元数据中
对于结果中的每个n:
显示源节点(n, show_source_metadata=True)
Node ID: 7c0e4c94-c9a6-4917-aa3f-e3b3f4cbcd5c
Similarity: None
Text:
Metadata: {'city_name': 'Tokyo', 'population': 13960000}
Node ID: 3c1d1caa-cec2-451e-8fd1-adc944e1d050
Similarity: None
Text:
Metadata: {'city_name': 'Seoul', 'population': 9776000}
Node ID: fb9f9b25-b913-4dde-a0e3-6111f704aea9
Similarity: None
Text:
Metadata: {'city_name': 'Toronto', 'population': 2930000}
Node ID: c31ba8e7-de5d-4f28-a464-5e0339547c70
Similarity: None
Text:
Metadata: {'city_name': 'Chicago', 'population': 2679000}
插入我们的RetrieverQueryEngine
¶
我们将我们的SQL Retriever与标准的RetrieverQueryEngine
组合,以合成一个响应。结果大致类似于我们打包的Text-to-SQL
查询引擎。
from llama_index.core.query_engine import RetrieverQueryEngine
query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)
response = query_engine.query(
"Return the top 5 cities (along with their populations) with the highest population."
)
print(str(response))
The top 5 cities with the highest population are: 1. Tokyo - 13,960,000 2. Seoul - 9,776,000 3. Toronto - 2,930,000 4. Chicago - 2,679,000