[Beta] 使用PGVector进行文本到SQL的转换¶
这个笔记本演示了如何使用pgvector执行文本到SQL的转换。这使我们能够在SQL中同时进行语义搜索和结构化查询!
这理论上比语义搜索+元数据过滤能够实现更具表达力的查询。
注意:这是一个测试功能,接口可能会发生变化。但与此同时,希望您会发现它有用!
注意:任何文本到SQL的应用都应该意识到执行任意SQL查询可能存在安全风险。建议采取必要的预防措施,比如使用受限角色、只读数据库、沙盒等。
设置数据¶
加载文档¶
加载Lyft 2021年的10k文件。
In [ ]:
Copied!
%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai
%pip install llama-index-embeddings-huggingface
%pip install llama-index-readers-file
%pip install llama-index-llms-openai
In [ ]:
Copied!
from llama_index.readers.file import PDFReader
from llama_index.readers.file import PDFReader
In [ ]:
Copied!
reader = PDFReader()
reader = PDFReader()
下载数据
In [ ]:
Copied!
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
!mkdir -p 'data/10k/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/10k/lyft_2021.pdf' -O 'data/10k/lyft_2021.pdf'
In [ ]:
Copied!
docs = reader.load_data("./data/10k/lyft_2021.pdf")
docs = reader.load_data("./data/10k/lyft_2021.pdf")
In [ ]:
Copied!
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(docs)
from llama_index.core.node_parser import SentenceSplitter
node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(docs)
In [ ]:
Copied!
print(nodes[8].get_content(metadata_mode="all"))
print(nodes[8].get_content(metadata_mode="all"))
将数据插入到Postgres + PGVector¶
确保已安装所有必要的依赖项!
In [ ]:
Copied!
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet
!pip install psycopg2-binary pgvector asyncpg "sqlalchemy[asyncio]" greenlet
In [ ]:
Copied!
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column
from pgvector.sqlalchemy import Vector
from sqlalchemy import insert, create_engine, String, text, Integer
from sqlalchemy.orm import declarative_base, mapped_column
# 建立连接
In [ ]:
Copied!
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
conn.commit()
engine = create_engine("postgresql+psycopg2://localhost/postgres")
with engine.connect() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
conn.commit()
定义表结构¶
定义为Python类。注意我们存储页面标签、嵌入和文本。
In [ ]:
Copied!
Base = declarative_base()
class SECTextChunk(Base):
__tablename__ = "sec_text_chunk"
id = mapped_column(Integer, primary_key=True)
page_label = mapped_column(Integer)
file_name = mapped_column(String)
text = mapped_column(String)
embedding = mapped_column(Vector(384))
Base = declarative_base()
class SECTextChunk(Base):
__tablename__ = "sec_text_chunk"
id = mapped_column(Integer, primary_key=True)
page_label = mapped_column(Integer)
file_name = mapped_column(String)
text = mapped_column(String)
embedding = mapped_column(Vector(384))
In [ ]:
Copied!
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
使用sentence_transformers模型为每个节点生成嵌入向量¶
In [ ]:
Copied!
# 为每一行获取嵌入
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
for node in nodes:
text_embedding = embed_model.get_text_embedding(node.get_content())
node.embedding = text_embedding
# 为每一行获取嵌入
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en")
for node in nodes:
text_embedding = embed_model.get_text_embedding(node.get_content())
node.embedding = text_embedding
插入数据库¶
In [ ]:
Copied!
# 插入到数据库
for node in nodes:
row_dict = {
"text": node.get_content(),
"embedding": node.embedding,
**node.metadata,
}
stmt = insert(SECTextChunk).values(**row_dict)
with engine.connect() as connection:
cursor = connection.execute(stmt)
connection.commit()
# 插入到数据库
for node in nodes:
row_dict = {
"text": node.get_content(),
"embedding": node.embedding,
**node.metadata,
}
stmt = insert(SECTextChunk).values(**row_dict)
with engine.connect() as connection:
cursor = connection.execute(stmt)
connection.commit()
定义PGVectorSQLQueryEngine¶
现在我们已经将数据加载到数据库中,我们准备设置我们的查询引擎。
定义提示¶
我们创建了一个修改过的默认文本到SQL提示的版本,以注入对pgvector语法的认识。 我们还使用一些少量示例来提示如何使用这种语法(<-->)。
注意:这在PGVectorSQLQueryEngine
中默认包含,我们在这里主要是为了让大家看到!
In [ ]:
Copied!
from llama_index.core import PromptTemplate
text_to_sql_tmpl = """\
给定一个输入问题,首先创建一个语法正确的{dialect}查询来运行,然后查看查询的结果并返回答案。您可以按相关列对结果进行排序,以返回数据库中最有趣的示例。
注意只使用模式描述中可见的列名。小心不要查询不存在的列。注意哪个列在哪个表中。此外,需要时用表名限定列名。
重要提示:您可以使用专门的pgvector语法(`<->`)来对表中的嵌入列进行最近邻/语义搜索到给定向量。给定行的嵌入值通常表示该行的语义含义。向量表示问题的嵌入表示,如下所示。不要直接填写向量值,而是指定一个`[query_vector]`占位符。例如,下面是一些选择语句示例(嵌入列的名称为`embedding`):
SELECT * FROM items ORDER BY embedding <-> '[query_vector]' LIMIT 5;
SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5;
SELECT * FROM items WHERE embedding <-> '[query_vector]' < 5;
您需要使用以下格式,每行一个:
问题:问题在这里
SQL查询:要运行的SQL查询
SQL结果:SQL查询的结果
答案:最终答案在这里
只使用下面列出的表。
{schema}
问题:{query_str}
SQL查询:\
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
from llama_index.core import PromptTemplate
text_to_sql_tmpl = """\
给定一个输入问题,首先创建一个语法正确的{dialect}查询来运行,然后查看查询的结果并返回答案。您可以按相关列对结果进行排序,以返回数据库中最有趣的示例。
注意只使用模式描述中可见的列名。小心不要查询不存在的列。注意哪个列在哪个表中。此外,需要时用表名限定列名。
重要提示:您可以使用专门的pgvector语法(`<->`)来对表中的嵌入列进行最近邻/语义搜索到给定向量。给定行的嵌入值通常表示该行的语义含义。向量表示问题的嵌入表示,如下所示。不要直接填写向量值,而是指定一个`[query_vector]`占位符。例如,下面是一些选择语句示例(嵌入列的名称为`embedding`):
SELECT * FROM items ORDER BY embedding <-> '[query_vector]' LIMIT 5;
SELECT * FROM items WHERE id != 1 ORDER BY embedding <-> (SELECT embedding FROM items WHERE id = 1) LIMIT 5;
SELECT * FROM items WHERE embedding <-> '[query_vector]' < 5;
您需要使用以下格式,每行一个:
问题:问题在这里
SQL查询:要运行的SQL查询
SQL结果:SQL查询的结果
答案:最终答案在这里
只使用下面列出的表。
{schema}
问题:{query_str}
SQL查询:\
"""
text_to_sql_prompt = PromptTemplate(text_to_sql_tmpl)
设置LLM、嵌入模型和其他杂项。¶
除了LLM和嵌入模型之外,注意我们还在表格本身上添加了注释。这有助于LLM更好地理解列模式(例如,通过告诉它嵌入列代表什么)以更好地进行表格查询或语义搜索。
In [ ]:
Copied!
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings
sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])
Settings.llm = OpenAI(model="gpt-4")
Settings.embed_model = embed_model
table_desc = """\
这个表代表来自SEC报告的文本块。每一行包含以下列:
id: 行的id
page_label: 页码
file_name: 顶层文件名
text: 所有文本块都在这里
embedding: 代表文本块的嵌入
对于大多数查询,您应该针对`embedding`列的值执行语义搜索,因为它编码了文本的含义。
"""
context_query_kwargs = {"sec_text_chunk": table_desc}
from llama_index.core import SQLDatabase
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import PGVectorSQLQueryEngine
from llama_index.core import Settings
sql_database = SQLDatabase(engine, include_tables=["sec_text_chunk"])
Settings.llm = OpenAI(model="gpt-4")
Settings.embed_model = embed_model
table_desc = """\
这个表代表来自SEC报告的文本块。每一行包含以下列:
id: 行的id
page_label: 页码
file_name: 顶层文件名
text: 所有文本块都在这里
embedding: 代表文本块的嵌入
对于大多数查询,您应该针对`embedding`列的值执行语义搜索,因为它编码了文本的含义。
"""
context_query_kwargs = {"sec_text_chunk": table_desc}
定义查询引擎¶
In [ ]:
Copied!
query_engine = PGVectorSQLQueryEngine(
sql_database=sql_database,
text_to_sql_prompt=text_to_sql_prompt,
context_query_kwargs=context_query_kwargs,
)
query_engine = PGVectorSQLQueryEngine(
sql_database=sql_database,
text_to_sql_prompt=text_to_sql_prompt,
context_query_kwargs=context_query_kwargs,
)
运行一些查询¶
现在我们已经准备好运行一些查询。
In [ ]:
Copied!
response = query_engine.query(
"Can you tell me about the risk factors described in page 6?",
)
response = query_engine.query(
"Can you tell me about the risk factors described in page 6?",
)
In [ ]:
Copied!
print(str(response))
print(str(response))
Page 6 discusses the impact of the COVID-19 pandemic on the business. It mentions that the pandemic has affected communities in the United States, Canada, and globally. The pandemic has led to a significant decrease in the demand for ridesharing services, which has negatively impacted the company's financial performance. The page also discusses the company's efforts to adapt to the changing environment by focusing on the delivery of essential goods and services. Additionally, it mentions the company's transportation network, which offers riders seamless, personalized, and on-demand access to a variety of mobility options.
In [ ]:
Copied!
print(response.metadata["sql_query"])
print(response.metadata["sql_query"])
In [ ]:
Copied!
response = query_engine.query(
"Tell me more about Lyft's real estate operating leases",
)
response = query_engine.query(
"Tell me more about Lyft's real estate operating leases",
)
In [ ]:
Copied!
print(str(response))
print(str(response))
Lyft's lease arrangements include vehicle rental programs, office space, and data centers. Leases that do not meet any specific criteria are accounted for as operating leases. The lease term begins when Lyft is available to use the underlying asset and ends upon the termination of the lease. The lease term includes any periods covered by an option to extend if Lyft is reasonably certain to exercise that option. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.
In [ ]:
Copied!
print(response.metadata["sql_query"][:300])
print(response.metadata["sql_query"][:300])
SELECT * FROM sec_text_chunk WHERE text LIKE '%Lyft%' AND text LIKE '%real estate%' AND text LIKE '%operating leases%' ORDER BY embedding <-> '[-0.007079003844410181, -0.04383348673582077, 0.02910166047513485, 0.02049737051129341, 0.009460929781198502, -0.017539210617542267, 0.04225028306245804, 0.0
In [ ]:
Copied!
# 查看返回的结果
print(response.metadata["result"])
# 查看返回的结果
print(response.metadata["result"])
[(157, 93, 'lyft_2021.pdf', "Leases that do not meet any of the above criteria are accounted for as operating leases.Lessor\nThe\n Company's lease arrangements include vehicle re ... (4356 characters truncated) ... realized. Leasehold improvements are amortized on a straight-line basis over the shorter of the term of the lease, or the useful life of the assets.", '[0.017818017,-0.024016099,0.0042511695,0.03114478,0.003591422,-0.0097886855,0.02455732,0.013048866,0.018157514,-0.009401044,0.031699456,0.01678178,0. ... (4472 characters truncated) ... 6,0.01127416,0.045080125,-0.017046565,-0.028544193,-0.016320521,0.01062995,-0.021007432,-0.006999497,-0.08426073,-0.014918887,0.059064835,0.03307945]')]
In [ ]:
Copied!
# 结构化查询
response = query_engine.query(
"告诉我这个表中最大的页面编号是多少",
)
# 结构化查询
response = query_engine.query(
"告诉我这个表中最大的页面编号是多少",
)
In [ ]:
Copied!
print(str(response))
print(str(response))
The maximum page number in this table is 238.
In [ ]:
Copied!
print(response.metadata["sql_query"][:300])
print(response.metadata["sql_query"][:300])
SELECT MAX(page_label) FROM sec_text_chunk;