使用Gradient和LlamaIndex进行Text-to-SQL的微调¶
在这个笔记本中,我们将向您展示如何在sql-create-context数据集上对llama2-7b进行微调,以使其在Text-to-SQL方面表现更好。
我们将使用gradient.ai来实现这一目标。
注意:这是我们关于使用Modal对llama2-7b进行微调的repo/guide的另一种选择:https://github.com/run-llama/modal_finetune_sql
注意:任何Text-to-SQL应用程序都应意识到执行任意SQL查询可能存在安全风险。建议根据需要采取预防措施,例如使用受限角色、只读数据库、沙箱等。
In [ ]:
Copied!
%pip install llama-index-llms-gradient
%pip install llama-index-finetuning
%pip install llama-index-llms-gradient
%pip install llama-index-finetuning
In [ ]:
Copied!
!pip install llama-index gradientai -q
!pip install llama-index gradientai -q
In [ ]:
Copied!
import os
from llama_index.llms.gradient import GradientBaseModelLLM
from llama_index.finetuning import GradientFinetuneEngine
import os
from llama_index.llms.gradient import GradientBaseModelLLM
from llama_index.finetuning import GradientFinetuneEngine
In [ ]:
Copied!
os.environ["GRADIENT_ACCESS_TOKEN"] = os.getenv("GRADIENT_API_KEY")
os.environ["GRADIENT_WORKSPACE_ID"] = ""
os.environ["GRADIENT_ACCESS_TOKEN"] = os.getenv("GRADIENT_API_KEY")
os.environ["GRADIENT_WORKSPACE_ID"] = ""
准备数据¶
我们从Hugging Face数据集中加载sql-create-context数据集。该数据集是WikiSQL和Spider的混合体,以输入查询、上下文和真实的SQL语句的格式进行组织。上下文是一个CREATE TABLE语句。
In [ ]:
Copied!
dialect = "sqlite"
dialect = "sqlite"
加载数据,保存到目录¶
In [ ]:
Copied!
from datasets import load_dataset
from pathlib import Path
import json
def load_jsonl(data_dir):
data_path = Path(data_dir).as_posix()
data = load_dataset("json", data_files=data_path)
return data
def save_jsonl(data_dicts, out_path):
with open(out_path, "w") as fp:
for data_dict in data_dicts:
fp.write(json.dumps(data_dict) + "\n")
def load_data_sql(data_dir: str = "data_sql"):
dataset = load_dataset("b-mc2/sql-create-context")
dataset_splits = {"train": dataset["train"]}
out_path = Path(data_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
for key, ds in dataset_splits.items():
with open(out_path, "w") as f:
for item in ds:
newitem = {
"input": item["question"],
"context": item["context"],
"output": item["answer"],
}
f.write(json.dumps(newitem) + "\n")
from datasets import load_dataset
from pathlib import Path
import json
def load_jsonl(data_dir):
data_path = Path(data_dir).as_posix()
data = load_dataset("json", data_files=data_path)
return data
def save_jsonl(data_dicts, out_path):
with open(out_path, "w") as fp:
for data_dict in data_dicts:
fp.write(json.dumps(data_dict) + "\n")
def load_data_sql(data_dir: str = "data_sql"):
dataset = load_dataset("b-mc2/sql-create-context")
dataset_splits = {"train": dataset["train"]}
out_path = Path(data_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
for key, ds in dataset_splits.items():
with open(out_path, "w") as f:
for item in ds:
newitem = {
"input": item["question"],
"context": item["context"],
"output": item["answer"],
}
f.write(json.dumps(newitem) + "\n")
In [ ]:
Copied!
# 将数据转储到data_sql
load_data_sql(data_dir="data_sql")
# 将数据转储到data_sql
load_data_sql(data_dir="data_sql")
分割为训练/验证集¶
In [ ]:
Copied!
from math import ceil
def get_train_val_splits(
data_dir: str = "data_sql",
val_ratio: float = 0.1,
seed: int = 42,
shuffle: bool = True,
):
data = load_jsonl(data_dir)
num_samples = len(data["train"])
val_set_size = ceil(val_ratio * num_samples)
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=shuffle, seed=seed
)
return train_val["train"].shuffle(), train_val["test"].shuffle()
from math import ceil
def get_train_val_splits(
data_dir: str = "data_sql",
val_ratio: float = 0.1,
seed: int = 42,
shuffle: bool = True,
):
data = load_jsonl(data_dir)
num_samples = len(data["train"])
val_set_size = ceil(val_ratio * num_samples)
train_val = data["train"].train_test_split(
test_size=val_set_size, shuffle=shuffle, seed=seed
)
return train_val["train"].shuffle(), train_val["test"].shuffle()
In [ ]:
Copied!
raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
save_jsonl(raw_train_data, "train_data_raw.jsonl")
save_jsonl(raw_val_data, "val_data_raw.jsonl")
raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
save_jsonl(raw_train_data, "train_data_raw.jsonl")
save_jsonl(raw_val_data, "val_data_raw.jsonl")
Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]
Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]
Generating train split: 0 examples [00:00, ? examples/s]
In [ ]:
Copied!
raw_train_data[0]
raw_train_data[0]
Out[ ]:
{'input': 'If the record is 5-5, what is the game maximum?', 'context': 'CREATE TABLE table_23285805_4 (game INTEGER, record VARCHAR)', 'output': 'SELECT MAX(game) FROM table_23285805_4 WHERE record = "5-5"'}
将训练/数据集字典映射到提示¶
在这里,我们定义函数将数据集字典映射到一个提示格式,然后我们可以将其提供给gradient.ai的微调端点。
In [ ]:
Copied!
### 格式类似于nous-hermes LLMs
text_to_sql_tmpl_str = """\
<s>### 指令:\n{system_message}{user_message}\n\n### 响应:\n{response}</s>"""
text_to_sql_inference_tmpl_str = """\
<s>### 指令:\n{system_message}{user_message}\n\n### 响应:\n"""
### 替代格式
### 推荐使用gradient.ai文档,但我们在实践中发现结果更差
# text_to_sql_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] {response} </s>"""
# text_to_sql_inference_tmpl_str = """\
# <s>[INST] SYS\n{system_message}\n<</SYS>>\n\n{user_message} [/INST] """
def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
system_message = f"""你是一个强大的文本到SQL模型。你的工作是回答关于数据库的问题。你会得到关于一个或多个表的问题和上下文。
你必须输出能回答问题的SQL查询。
"""
user_message = f"""### 方言:
{dialect}
### 输入:
{input}
### 上下文:
{context}
### 响应:
"""
if output:
return text_to_sql_tmpl_str.format(
system_message=system_message,
user_message=user_message,
response=output,
)
else:
return text_to_sql_inference_tmpl_str.format(
system_message=system_message, user_message=user_message
)
def generate_prompt(data_point):
full_prompt = _generate_prompt_sql(
data_point["input"],
data_point["context"],
dialect="sqlite",
output=data_point["output"],
)
return {"inputs": full_prompt}
### 格式类似于nous-hermes LLMs
text_to_sql_tmpl_str = """\
### 指令:\n{system_message}{user_message}\n\n### 响应:\n{response}"""
text_to_sql_inference_tmpl_str = """\
### 指令:\n{system_message}{user_message}\n\n### 响应:\n"""
### 替代格式
### 推荐使用gradient.ai文档,但我们在实践中发现结果更差
# text_to_sql_tmpl_str = """\
# [INST] SYS\n{system_message}\n<>\n\n{user_message} [/INST] {response} """
# text_to_sql_inference_tmpl_str = """\
# [INST] SYS\n{system_message}\n<>\n\n{user_message} [/INST] """
def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
system_message = f"""你是一个强大的文本到SQL模型。你的工作是回答关于数据库的问题。你会得到关于一个或多个表的问题和上下文。
你必须输出能回答问题的SQL查询。
"""
user_message = f"""### 方言:
{dialect}
### 输入:
{input}
### 上下文:
{context}
### 响应:
"""
if output:
return text_to_sql_tmpl_str.format(
system_message=system_message,
user_message=user_message,
response=output,
)
else:
return text_to_sql_inference_tmpl_str.format(
system_message=system_message, user_message=user_message
)
def generate_prompt(data_point):
full_prompt = _generate_prompt_sql(
data_point["input"],
data_point["context"],
dialect="sqlite",
output=data_point["output"],
)
return {"inputs": full_prompt}
In [ ]:
Copied!
train_data = [
{"inputs": d["inputs"] for d in raw_train_data.map(generate_prompt)}
]
save_jsonl(train_data, "train_data.jsonl")
val_data = [{"inputs": d["inputs"] for d in raw_val_data.map(generate_prompt)}]
save_jsonl(val_data, "val_data.jsonl")
train_data = [
{"inputs": d["inputs"] for d in raw_train_data.map(generate_prompt)}
]
save_jsonl(train_data, "train_data.jsonl")
val_data = [{"inputs": d["inputs"] for d in raw_val_data.map(generate_prompt)}]
save_jsonl(val_data, "val_data.jsonl")
In [ ]:
Copied!
print(train_data[0]["inputs"])
print(train_data[0]["inputs"])
<s>### Instruction: You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question. ### Dialect: sqlite ### Input: Who had the fastest lap in bowmanville, ontario? ### Context: CREATE TABLE table_30134667_2 (fastest_lap VARCHAR, location VARCHAR) ### Response: ### Response: SELECT fastest_lap FROM table_30134667_2 WHERE location = "Bowmanville, Ontario"</s>
使用gradient.ai进行微调¶
在这里,我们使用GradientFinetuneEngine
调用Gradient的微调端点。
为了示例目的,我们限制了步骤,但您可以随意修改参数。
最后,我们获取我们微调后的LLM。
In [ ]:
Copied!
# 基础模型标识 = "nous-hermes2"
base_model_slug = "llama2-7b-chat"
base_llm = GradientBaseModelLLM(
base_model_slug=base_model_slug, max_tokens=300
)
# 基础模型标识 = "nous-hermes2"
base_model_slug = "llama2-7b-chat"
base_llm = GradientBaseModelLLM(
base_model_slug=base_model_slug, max_tokens=300
)
In [ ]:
Copied!
# 步骤 最大步数为20,仅用于测试目的
# 注意:只能指定base_model_slug或model_adapter_id中的一个
finetune_engine = GradientFinetuneEngine(
base_model_slug=base_model_slug,
# model_adapter_id='805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter',
name="text_to_sql",
data_path="train_data.jsonl",
verbose=True,
max_steps=200,
batch_size=4,
)
# 步骤 最大步数为20,仅用于测试目的
# 注意:只能指定base_model_slug或model_adapter_id中的一个
finetune_engine = GradientFinetuneEngine(
base_model_slug=base_model_slug,
# model_adapter_id='805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter',
name="text_to_sql",
data_path="train_data.jsonl",
verbose=True,
max_steps=200,
batch_size=4,
)
In [ ]:
Copied!
finetune_engine.model_adapter_id
finetune_engine.model_adapter_id
Out[ ]:
'805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter'
In [ ]:
Copied!
epochs = 1
for i in range(epochs):
print(f"** EPOCH {i} **")
finetune_engine.finetune()
epochs = 1
for i in range(epochs):
print(f"** EPOCH {i} **")
finetune_engine.finetune()
In [ ]:
Copied!
ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)
ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)
评估¶
这包括两个部分:
- 我们在验证数据集中对一些样本数据点进行评估。
- 我们在一个新的玩具SQL数据集上进行评估,并将经过微调的LLM插入到我们的
NLSQLTableQueryEngine
中,以运行完整的文本到SQL的工作流程。
第一部分:在验证数据集数据点上的评估¶
In [ ]:
Copied!
def get_text2sql_completion(llm, raw_datapoint):
text2sql_tmpl_str = _generate_prompt_sql(
raw_datapoint["input"],
raw_datapoint["context"],
dialect="sqlite",
output=None,
)
response = llm.complete(text2sql_tmpl_str)
return str(response)
def get_text2sql_completion(llm, raw_datapoint):
text2sql_tmpl_str = _generate_prompt_sql(
raw_datapoint["input"],
raw_datapoint["context"],
dialect="sqlite",
output=None,
)
response = llm.complete(text2sql_tmpl_str)
return str(response)
In [ ]:
Copied!
test_datapoint = raw_val_data[2]
display(test_datapoint)
test_datapoint = raw_val_data[2]
display(test_datapoint)
{'input': ' how many\xa0reverse\xa0with\xa0series\xa0being iii series', 'context': 'CREATE TABLE table_12284476_8 (reverse VARCHAR, series VARCHAR)', 'output': 'SELECT COUNT(reverse) FROM table_12284476_8 WHERE series = "III series"'}
In [ ]:
Copied!
# 运行基础llama2-7b-chat模型
get_text2sql_completion(base_llm, test_datapoint)
# 运行基础llama2-7b-chat模型
get_text2sql_completion(base_llm, test_datapoint)
In [ ]:
Copied!
# 运行微调的llama2-7b-chat模型
get_text2sql_completion(ft_llm, test_datapoint)
# 运行微调的llama2-7b-chat模型
get_text2sql_completion(ft_llm, test_datapoint)
Out[ ]:
'SELECT MIN(year) FROM table_name_35 WHERE venue = "barcelona, spain"'
第二部分:对一个玩具数据集进行评估¶
在这里,我们创建了一个包含城市及其人口的玩具数据表。
创建表格¶
In [ ]:
Copied!
# 创建样本
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
select,
column,
)
from llama_index.core import SQLDatabase
# 创建样本
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
String,
Integer,
select,
column,
)
from llama_index.core import SQLDatabase
In [ ]:
Copied!
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()
In [ ]:
Copied!
# 创建城市SQL表
table_name = "city_stats"
city_stats_table = Table(
table_name,
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
# 创建城市SQL表
table_name = "city_stats"
city_stats_table = Table(
table_name,
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
In [ ]:
Copied!
# 这个上下文稍后会被使用
from sqlalchemy.schema import CreateTable
table_create_stmt = str(CreateTable(city_stats_table))
print(table_create_stmt)
# 这个上下文稍后会被使用
from sqlalchemy.schema import CreateTable
table_create_stmt = str(CreateTable(city_stats_table))
print(table_create_stmt)
CREATE TABLE city_stats ( city_name VARCHAR(16) NOT NULL, population INTEGER, country VARCHAR(16) NOT NULL, PRIMARY KEY (city_name) )
In [ ]:
Copied!
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
使用测试数据填充¶
In [ ]:
Copied!
# 插入示例行
from sqlalchemy import insert
rows = [
{"city_name": "多伦多", "population": 2930000, "country": "加拿大"},
{"city_name": "东京", "population": 13960000, "country": "日本"},
{
"city_name": "芝加哥",
"population": 2679000,
"country": "美国",
},
{"city_name": "首尔", "population": 9776000, "country": "韩国"},
]
for row in rows:
stmt = insert(city_stats_table).values(**row)
with engine.connect() as connection:
cursor = connection.execute(stmt)
connection.commit()
# 插入示例行
from sqlalchemy import insert
rows = [
{"city_name": "多伦多", "population": 2930000, "country": "加拿大"},
{"city_name": "东京", "population": 13960000, "country": "日本"},
{
"city_name": "芝加哥",
"population": 2679000,
"country": "美国",
},
{"city_name": "首尔", "population": 9776000, "country": "韩国"},
]
for row in rows:
stmt = insert(city_stats_table).values(**row)
with engine.connect() as connection:
cursor = connection.execute(stmt)
connection.commit()
获取Text2SQL查询引擎¶
In [ ]:
Copied!
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import PromptTemplate
def get_text2sql_query_engine(llm, table_context, sql_database):
# 我们实质上是将现有的模板变量替换为新的模板变量
# 放入我们的 `NLSQLTableQueryEngine` 中
text2sql_tmpl_str = _generate_prompt_sql(
"{query_str}", "{schema}", dialect="{dialect}", output=""
)
sql_prompt = PromptTemplate(text2sql_tmpl_str)
# 在这里,我们明确将表上下文设置为 CREATE TABLE 字符串
# 所以我们将 `tables` 设置为空,并且硬性修复 `context_str` 前缀
query_engine = NLSQLTableQueryEngine(
sql_database,
tables=[],
context_str_prefix=table_context,
text_to_sql_prompt=sql_prompt,
llm=llm,
synthesize_response=False,
)
return query_engine
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import PromptTemplate
def get_text2sql_query_engine(llm, table_context, sql_database):
# 我们实质上是将现有的模板变量替换为新的模板变量
# 放入我们的 `NLSQLTableQueryEngine` 中
text2sql_tmpl_str = _generate_prompt_sql(
"{query_str}", "{schema}", dialect="{dialect}", output=""
)
sql_prompt = PromptTemplate(text2sql_tmpl_str)
# 在这里,我们明确将表上下文设置为 CREATE TABLE 字符串
# 所以我们将 `tables` 设置为空,并且硬性修复 `context_str` 前缀
query_engine = NLSQLTableQueryEngine(
sql_database,
tables=[],
context_str_prefix=table_context,
text_to_sql_prompt=sql_prompt,
llm=llm,
synthesize_response=False,
)
return query_engine
In [ ]:
Copied!
# 查询 = "哪些城市的人口少于1000万人?"
查询 = "东京的人口是多少?(确保城市/国家的名称首字母大写)"
# 查询 = "这些城市的平均人口和总人口是多少?"
# 查询 = "哪些城市的人口少于1000万人?"
查询 = "东京的人口是多少?(确保城市/国家的名称首字母大写)"
# 查询 = "这些城市的平均人口和总人口是多少?"
使用基本llama2模型的结果¶
基本的llama2模型在SQL语句中添加了大量文本,这破坏了我们的解析器(并且有一些小写字母的错误)。
In [ ]:
Copied!
base_query_engine = get_text2sql_query_engine(
base_llm, table_create_stmt, sql_database
)
base_query_engine = get_text2sql_query_engine(
base_llm, table_create_stmt, sql_database
)
In [ ]:
Copied!
base_response = base_query_engine.query(query)
base_response = base_query_engine.query(query)
In [ ]:
Copied!
print(str(base_response))
print(str(base_response))
Error: You can only execute one statement at a time.
In [ ]:
Copied!
base_response.metadata["sql_query"]
base_response.metadata["sql_query"]
Out[ ]:
"SELECT population FROM city_stats WHERE country = 'JAPAN';\n\nThis will return the population of Tokyo, which is the only city in the table with a population value."
经过微调模型的结果¶
In [ ]:
Copied!
ft_query_engine = get_text2sql_query_engine(
ft_llm, table_create_stmt, sql_database
)
ft_query_engine = get_text2sql_query_engine(
ft_llm, table_create_stmt, sql_database
)
In [ ]:
Copied!
ft_response = ft_query_engine.query(query)
ft_response = ft_query_engine.query(query)
In [ ]:
Copied!
print(str(ft_response))
print(str(ft_response))
[(13960000,)]
In [ ]:
Copied!
ft_response.metadata["sql_query"]
ft_response.metadata["sql_query"]
Out[ ]:
'SELECT population FROM city_stats WHERE country = "Japan" AND city_name = "Tokyo"'