Skip to content

Sql

NLSQLRetriever #

Bases: BaseRetriever, PromptMixin

文本到SQL检索器。

通过文本检索。

Parameters:

Name Type Description Default
sql_database SQLDatabase

SQL数据库。

required
text_to_sql_prompt BasePromptTemplate

用于文本到SQL的提示模板。 默认为DEFAULT_TEXT_TO_SQL_PROMPT。

None
context_query_kwargs dict

表名到附加上下文字符串的映射（该字符串会拼接在对应表的模式描述之后）。默认为None。

None
tables Union[List[str], List[Table]]

表名列表或Table对象的列表。

None
table_retriever ObjectRetriever[SQLTableSchema]

用于SQLTableSchema对象的对象检索器。 默认为None。

None
context_str_prefix str

上下文字符串的前缀。默认为None。

None
service_context ServiceContext

服务上下文。默认为None。

None
return_raw bool

若为True，返回SQL结果的纯文本转储；若为False，则将每一行结果解析为一个Node。

True
handle_sql_errors bool

是否处理SQL错误：若为True，SQL执行出错时返回包含错误信息的节点，而不是抛出异常。默认为True。

True
sql_only bool

是否仅获取SQL而不是SQL查询结果。 默认为False。

False
llm Optional[LLM]

要使用的语言模型。

None
Source code in llama_index/core/indices/struct_store/sql_retriever.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
class NLSQLRetriever(BaseRetriever, PromptMixin):
    """Text-to-SQL retriever.

    Retrieves from a SQL database given a natural-language query: an LLM
    translates the query into SQL, which is then executed (unless
    ``sql_only`` is set).

    Args:
        sql_database (SQLDatabase): SQL database to query.
        text_to_sql_prompt (BasePromptTemplate): Prompt template used for
            text-to-SQL. Defaults to DEFAULT_TEXT_TO_SQL_PROMPT.
        context_query_kwargs (dict): Mapping of table name to an extra context
            string appended to that table's schema description. Defaults to None.
        tables (Union[List[str], List[Table]]): Table names or Table objects to
            expose. Defaults to all usable tables of ``sql_database``.
        table_retriever (ObjectRetriever[SQLTableSchema]): Retriever that picks
            SQLTableSchema objects per query; takes precedence over ``tables``.
            Defaults to None.
        context_str_prefix (str): Prefix placed before the table context string.
            Defaults to None.
        sql_parser_mode (SQLParserMode): How the LLM response is parsed into SQL.
            Defaults to SQLParserMode.DEFAULT.
        llm (Optional[LLM]): LLM used for text-to-SQL. Defaults to the LLM
            resolved from Settings / service_context.
        embed_model (Optional[BaseEmbedding]): Embedding model, only passed to
            the PGVECTOR parser. Defaults to the model resolved from
            Settings / service_context.
        service_context (ServiceContext): Service context. Defaults to None.
        return_raw (bool): If True, return a plain-text dump of the SQL result;
            otherwise parse result rows into Nodes. Defaults to True.
        handle_sql_errors (bool): If True, SQL execution errors are returned as
            an error node instead of being raised. Defaults to True.
        sql_only (bool): If True, return only the generated SQL without
            executing it. Defaults to False.
        callback_manager (Optional[CallbackManager]): Callback manager.
        verbose (bool): Print the table context and generated SQL.
            Defaults to False.
    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        text_to_sql_prompt: Optional[BasePromptTemplate] = None,
        context_query_kwargs: Optional[dict] = None,
        tables: Optional[Union[List[str], List[Table]]] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
        context_str_prefix: Optional[str] = None,
        sql_parser_mode: SQLParserMode = SQLParserMode.DEFAULT,
        llm: Optional[LLM] = None,
        embed_model: Optional[BaseEmbedding] = None,
        service_context: Optional[ServiceContext] = None,
        return_raw: bool = True,
        handle_sql_errors: bool = True,
        sql_only: bool = False,
        callback_manager: Optional[CallbackManager] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize parameters."""
        self._sql_retriever = SQLRetriever(sql_database, return_raw=return_raw)
        self._sql_database = sql_database
        self._get_tables = self._load_get_tables_fn(
            sql_database, tables, context_query_kwargs, table_retriever
        )
        self._context_str_prefix = context_str_prefix
        self._llm = llm or llm_from_settings_or_context(Settings, service_context)
        self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
        self._sql_parser_mode = sql_parser_mode

        embed_model = embed_model or embed_model_from_settings_or_context(
            Settings, service_context
        )
        self._sql_parser = self._load_sql_parser(sql_parser_mode, embed_model)
        self._handle_sql_errors = handle_sql_errors
        self._sql_only = sql_only
        self._verbose = verbose
        super().__init__(
            callback_manager=callback_manager
            or callback_manager_from_settings_or_context(Settings, service_context)
        )

    def _get_prompts(self) -> Dict[str, Any]:
        """Return the prompts used by this retriever."""
        return {
            "text_to_sql_prompt": self._text_to_sql_prompt,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Replace the text-to-SQL prompt if one is supplied."""
        if "text_to_sql_prompt" in prompts:
            self._text_to_sql_prompt = prompts["text_to_sql_prompt"]

    def _get_prompt_modules(self) -> PromptMixinType:
        """Return sub-modules that own prompts (none here)."""
        return {}

    def _load_sql_parser(
        self, sql_parser_mode: SQLParserMode, embed_model: BaseEmbedding
    ) -> BaseSQLParser:
        """Build the SQL parser for ``sql_parser_mode``.

        Raises:
            ValueError: If the mode is not recognized.
        """
        if sql_parser_mode == SQLParserMode.DEFAULT:
            return DefaultSQLParser()
        elif sql_parser_mode == SQLParserMode.PGVECTOR:
            return PGVectorSQLParser(embed_model=embed_model)
        else:
            raise ValueError(f"Unknown SQL parser mode: {sql_parser_mode}")

    def _load_get_tables_fn(
        self,
        sql_database: SQLDatabase,
        tables: Optional[Union[List[str], List[Table]]] = None,
        context_query_kwargs: Optional[dict] = None,
        table_retriever: Optional[ObjectRetriever[SQLTableSchema]] = None,
    ) -> Callable[[str], List[SQLTableSchema]]:
        """Build a function mapping a query string to the table schemas to use.

        If a ``table_retriever`` is given it is queried per call; otherwise a
        fixed list of schemas (from ``tables`` or all usable tables) is
        returned regardless of the query.
        """
        context_query_kwargs = context_query_kwargs or {}
        if table_retriever is not None:
            return lambda query_str: cast(Any, table_retriever).retrieve(query_str)
        else:
            if tables is not None:
                table_names: List[str] = [
                    t.name if isinstance(t, Table) else t for t in tables
                ]
            else:
                table_names = list(sql_database.get_usable_table_names())
            context_strs = [context_query_kwargs.get(t, None) for t in table_names]
            table_schemas = [
                SQLTableSchema(table_name=t, context_str=c)
                for t, c in zip(table_names, context_strs)
            ]
            return lambda _: table_schemas

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Generate SQL for the query and, unless ``sql_only``, execute it.

        Returns:
            ``(nodes, metadata)``. ``metadata`` is ``{"sql_query": <sql>}``
            merged with the SQL retriever's metadata, with
            ``{"result": <sql>}`` in ``sql_only`` mode, or with nothing when
            an error was handled.

        Raises:
            Exception: Any error from executing the SQL, when
                ``handle_sql_errors`` is False.
        """
        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        table_desc_str = self._get_table_context(query_bundle)
        logger.info("> Table desc str: %s", table_desc_str)
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

        response_str = self._llm.predict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        # assume that it's a valid SQL query
        logger.debug("> Predicted SQL query: %s", sql_query_str)
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

        metadata: Dict[str, Any]
        if self._sql_only:
            sql_only_node = TextNode(text=f"{sql_query_str}")
            retrieved_nodes = [NodeWithScore(node=sql_only_node)]
            metadata = {"result": sql_query_str}
        else:
            try:
                retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                    sql_query_str
                )
            except Exception as e:
                # Exception, not BaseException: KeyboardInterrupt / SystemExit
                # must never be converted into an "Error:" node.
                if not self._handle_sql_errors:
                    raise
                err_node = TextNode(text=f"Error: {e!s}")
                retrieved_nodes = [NodeWithScore(node=err_node)]
                metadata = {}

        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Async variant of ``retrieve_with_metadata`` with the same result shape.

        Raises:
            Exception: Any error from executing the SQL, when
                ``handle_sql_errors`` is False.
        """
        if isinstance(str_or_query_bundle, str):
            query_bundle = QueryBundle(str_or_query_bundle)
        else:
            query_bundle = str_or_query_bundle
        table_desc_str = self._get_table_context(query_bundle)
        logger.info("> Table desc str: %s", table_desc_str)
        if self._verbose:
            print(f"> Table desc str: {table_desc_str}")

        response_str = await self._llm.apredict(
            self._text_to_sql_prompt,
            query_str=query_bundle.query_str,
            schema=table_desc_str,
            dialect=self._sql_database.dialect,
        )

        sql_query_str = self._sql_parser.parse_response_to_sql(
            response_str, query_bundle
        )
        # assume that it's a valid SQL query
        logger.debug("> Predicted SQL query: %s", sql_query_str)
        if self._verbose:
            print(f"> Predicted SQL query: {sql_query_str}")

        metadata: Dict[str, Any]
        if self._sql_only:
            sql_only_node = TextNode(text=f"{sql_query_str}")
            retrieved_nodes = [NodeWithScore(node=sql_only_node)]
            # Keep the metadata shape identical to the sync path.
            metadata = {"result": sql_query_str}
        else:
            try:
                (
                    retrieved_nodes,
                    metadata,
                ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
            except Exception as e:
                # Exception, not BaseException: asyncio.CancelledError,
                # KeyboardInterrupt and SystemExit must keep propagating so
                # task cancellation still works.
                if not self._handle_sql_errors:
                    raise
                err_node = TextNode(text=f"Error: {e!s}")
                retrieved_nodes = [NodeWithScore(node=err_node)]
                metadata = {}
        return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the query, discarding metadata."""
        retrieved_nodes, _ = self.retrieve_with_metadata(query_bundle)
        return retrieved_nodes

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes for the query, discarding metadata."""
        retrieved_nodes, _ = await self.aretrieve_with_metadata(query_bundle)
        return retrieved_nodes

    def _get_table_context(self, query_bundle: QueryBundle) -> str:
        """Build the table context for the prompt.

        Concatenates (with blank lines) the optional prefix and, for each
        selected table, its schema info plus its optional description.
        """
        table_schema_objs = self._get_tables(query_bundle.query_str)
        context_strs = []
        if self._context_str_prefix is not None:
            context_strs = [self._context_str_prefix]

        for table_schema_obj in table_schema_objs:
            table_info = self._sql_database.get_single_table_info(
                table_schema_obj.table_name
            )

            if table_schema_obj.context_str:
                table_opt_context = " The table description is: "
                table_opt_context += table_schema_obj.context_str
                table_info += table_opt_context

            context_strs.append(table_info)

        return "\n\n".join(context_strs)

retrieve_with_metadata #

retrieve_with_metadata(
    str_or_query_bundle: QueryType,
) -> Tuple[List[NodeWithScore], Dict]

使用元数据检索。

Source code in llama_index/core/indices/struct_store/sql_retriever.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Generate SQL for the query and, unless ``sql_only``, execute it.

    Returns:
        ``(nodes, metadata)``. ``metadata`` is ``{"sql_query": <sql>}`` merged
        with the SQL retriever's metadata, with ``{"result": <sql>}`` in
        ``sql_only`` mode, or with nothing when an error was handled.

    Raises:
        Exception: Any error from executing the SQL, when
            ``handle_sql_errors`` is False.
    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info("> Table desc str: %s", table_desc_str)
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = self._llm.predict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug("> Predicted SQL query: %s", sql_query_str)
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    metadata: Dict[str, Any]
    if self._sql_only:
        sql_only_node = TextNode(text=f"{sql_query_str}")
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        metadata = {"result": sql_query_str}
    else:
        try:
            retrieved_nodes, metadata = self._sql_retriever.retrieve_with_metadata(
                sql_query_str
            )
        except Exception as e:
            # Exception, not BaseException: KeyboardInterrupt / SystemExit
            # must never be converted into an "Error:" node.
            if not self._handle_sql_errors:
                raise
            err_node = TextNode(text=f"Error: {e!s}")
            retrieved_nodes = [NodeWithScore(node=err_node)]
            metadata = {}

    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

aretrieve_with_metadata async #

aretrieve_with_metadata(
    str_or_query_bundle: QueryType,
) -> Tuple[List[NodeWithScore], Dict]

异步获取带有元数据。

Source code in llama_index/core/indices/struct_store/sql_retriever.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
async def aretrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Async variant of ``retrieve_with_metadata`` with the same result shape.

    Returns:
        ``(nodes, metadata)``. ``metadata`` is ``{"sql_query": <sql>}`` merged
        with the SQL retriever's metadata, with ``{"result": <sql>}`` in
        ``sql_only`` mode, or with nothing when an error was handled.

    Raises:
        Exception: Any error from executing the SQL, when
            ``handle_sql_errors`` is False.
    """
    if isinstance(str_or_query_bundle, str):
        query_bundle = QueryBundle(str_or_query_bundle)
    else:
        query_bundle = str_or_query_bundle
    table_desc_str = self._get_table_context(query_bundle)
    logger.info("> Table desc str: %s", table_desc_str)
    if self._verbose:
        print(f"> Table desc str: {table_desc_str}")

    response_str = await self._llm.apredict(
        self._text_to_sql_prompt,
        query_str=query_bundle.query_str,
        schema=table_desc_str,
        dialect=self._sql_database.dialect,
    )

    sql_query_str = self._sql_parser.parse_response_to_sql(
        response_str, query_bundle
    )
    # assume that it's a valid SQL query
    logger.debug("> Predicted SQL query: %s", sql_query_str)
    if self._verbose:
        print(f"> Predicted SQL query: {sql_query_str}")

    metadata: Dict[str, Any]
    if self._sql_only:
        sql_only_node = TextNode(text=f"{sql_query_str}")
        retrieved_nodes = [NodeWithScore(node=sql_only_node)]
        # Same metadata shape as the synchronous retrieve_with_metadata.
        metadata = {"result": sql_query_str}
    else:
        try:
            (
                retrieved_nodes,
                metadata,
            ) = await self._sql_retriever.aretrieve_with_metadata(sql_query_str)
        except Exception as e:
            # Exception, not BaseException: asyncio.CancelledError,
            # KeyboardInterrupt and SystemExit must keep propagating so task
            # cancellation still works.
            if not self._handle_sql_errors:
                raise
            err_node = TextNode(text=f"Error: {e!s}")
            retrieved_nodes = [NodeWithScore(node=err_node)]
            metadata = {}
    return retrieved_nodes, {"sql_query": sql_query_str, **metadata}

SQLParserMode #

Bases: str, Enum

SQL解析器模式。

Source code in llama_index/core/indices/struct_store/sql_retriever.py
102
103
104
105
106
class SQLParserMode(str, Enum):
    """Strategy for turning an LLM text-to-SQL response into a SQL string.

    Because the enum also subclasses ``str``, each member compares equal to
    its plain string value (e.g. ``SQLParserMode.DEFAULT == "default"``).
    """

    # Selects DefaultSQLParser.
    DEFAULT = "default"
    # Selects PGVectorSQLParser, which is constructed with an embedding model.
    PGVECTOR = "pgvector"

SQLRetriever #

Bases: BaseRetriever

SQL检索器。

通过原始SQL语句检索。

Source code in llama_index/core/indices/struct_store/sql_retriever.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class SQLRetriever(BaseRetriever):
    """SQL retriever.

    Retrieves by executing a raw SQL statement against a SQL database.

    Args:
        sql_database (SQLDatabase): Database the statement is run against.
        return_raw (bool): If True, return a single node holding the raw result
            text; otherwise return one node per result row. Defaults to True.
    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        return_raw: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Store the database handle and output mode."""
        self._return_raw = return_raw
        self._sql_database = sql_database
        super().__init__(callback_manager)

    def _format_node_results(
        self, results: List[List[Any]], col_keys: List[str]
    ) -> List[NodeWithScore]:
        """Wrap every result row in a node whose metadata maps column -> value.

        Node text is intentionally left empty; the row lives in the metadata.
        """
        return [
            NodeWithScore(node=TextNode(text="", metadata=dict(zip(col_keys, row))))
            for row in results
        ]

    def retrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Run the SQL statement and return ``(nodes, metadata)``.

        ``metadata`` is passed through unchanged from ``run_sql``; in
        formatted mode its ``"result"`` and ``"col_keys"`` entries are read to
        build one node per row.
        """
        bundle = (
            QueryBundle(str_or_query_bundle)
            if isinstance(str_or_query_bundle, str)
            else str_or_query_bundle
        )
        raw_text, run_metadata = self._sql_database.run_sql(bundle.query_str)
        if not self._return_raw:
            rows = run_metadata["result"]
            columns = run_metadata["col_keys"]
            return self._format_node_results(rows, columns), run_metadata
        return [NodeWithScore(node=TextNode(text=raw_text))], run_metadata

    async def aretrieve_with_metadata(
        self, str_or_query_bundle: QueryType
    ) -> Tuple[List[NodeWithScore], Dict]:
        """Async entry point; delegates to the blocking synchronous call."""
        return self.retrieve_with_metadata(str_or_query_bundle)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return only the nodes produced for the query."""
        return self.retrieve_with_metadata(query_bundle)[0]

retrieve_with_metadata #

retrieve_with_metadata(
    str_or_query_bundle: QueryType,
) -> Tuple[List[NodeWithScore], Dict]

使用元数据检索。

Source code in llama_index/core/indices/struct_store/sql_retriever.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def retrieve_with_metadata(
    self, str_or_query_bundle: QueryType
) -> Tuple[List[NodeWithScore], Dict]:
    """Run the SQL statement and return ``(nodes, metadata)``.

    ``metadata`` is passed through unchanged from ``run_sql``; in formatted
    mode its ``"result"`` and ``"col_keys"`` entries are read to build one
    node per row.
    """
    bundle = (
        QueryBundle(str_or_query_bundle)
        if isinstance(str_or_query_bundle, str)
        else str_or_query_bundle
    )
    raw_text, run_metadata = self._sql_database.run_sql(bundle.query_str)
    if not self._return_raw:
        rows = run_metadata["result"]
        columns = run_metadata["col_keys"]
        return self._format_node_results(rows, columns), run_metadata
    return [NodeWithScore(node=TextNode(text=raw_text))], run_metadata