Source code for langchain_experimental.sql.vector_sql

"""Vector SQL Database Chain Retriever"""

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional, Sequence, Union

from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.embeddings import Embeddings
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

from langchain_experimental.sql.base import INTERMEDIATE_STEPS_KEY, SQLDatabaseChain


class VectorSQLOutputParser(BaseOutputParser[str]):
    """Output parser for Vector SQL.

    1. Finds every ``NeuralArray(<entity>)`` call and replaces it with the
       embedding of ``<entity>`` rendered as a ``[v1,v2,...]`` literal.
    2. Replaces ``DISTANCE`` with the distance function name used by the
       backend SQL. This step only runs when at least one ``NeuralArray(...)``
       call was substituted.
    """

    model: Embeddings
    """Embedding model used to turn the entity text into a vector."""
    distance_func_name: str = "distance"
    """Distance function name substituted for ``DISTANCE`` in the SQL."""

    class Config:
        arbitrary_types_allowed = True

    @property
    def _type(self) -> str:
        return "vector_sql_parser"

    @classmethod
    def from_embeddings(
        cls, model: Embeddings, distance_func_name: str = "distance", **kwargs: Any
    ) -> BaseOutputParser:
        """Build a parser from an embedding model.

        Args:
            model: Embedding model whose ``embed_query`` is called on each
                ``NeuralArray`` entity.
            distance_func_name: Name that replaces ``DISTANCE`` in the SQL.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new parser instance.
        """
        return cls(model=model, distance_func_name=distance_func_name, **kwargs)

    def parse(self, text: str) -> str:
        """Turn LLM-generated Vector SQL into executable SQL.

        Args:
            text: Raw SQL text, possibly containing ``NeuralArray(<entity>)``
                calls and ``DISTANCE`` placeholders.

        Returns:
            The stripped SQL with every well-formed ``NeuralArray(...)`` call
            replaced by its embedding literal, ``DISTANCE`` renamed (only when
            a substitution happened) and a single trailing ``;`` removed.
            A ``NeuralArray(`` with no closing ``)`` is left untouched.
        """
        sql = text.strip()
        # Identical entities are embedded once and reused.
        rendered: Dict[str, str] = {}

        def _embed(match: re.Match[str]) -> str:
            entity = match.group(1)
            if entity not in rendered:
                vector = self.model.embed_query(entity)
                rendered[entity] = "[" + ",".join(map(str, vector)) + "]"
            return rendered[entity]

        # The entity runs up to the first ")" after "NeuralArray(".
        # Substituting before the DISTANCE rename keeps the entity text intact
        # for embedding and guarantees the call is always found.
        sql, substituted = re.subn(r"NeuralArray\(([^)]*)\)", _embed, sql)
        if substituted:
            sql = sql.replace("DISTANCE", self.distance_func_name)
        # endswith() is safe on an empty string, unlike indexing with [-1].
        if sql.endswith(";"):
            sql = sql[:-1]
        return sql
class VectorSQLRetrieveAllOutputParser(VectorSQLOutputParser):
    """Parser based on VectorSQLOutputParser.

    It also modifies the SQL to select all columns.
    """

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_all_parser"

    def parse(self, text: str) -> str:
        """Rewrite the first column list to ``*`` and delegate to the parent.

        Only the span between the first ``SELECT`` keyword and the first
        ``FROM`` keyword that follows it is rewritten; the rest of the text is
        left exactly as it was. If either keyword is missing the text is
        passed through unchanged.

        Args:
            text: Raw SQL text produced by the LLM.

        Returns:
            The result of ``VectorSQLOutputParser.parse`` on the rewritten SQL.
        """
        text = text.strip()
        # Word boundaries stop identifiers such as "from_date" from matching.
        select_kw = re.search(r"\bSELECT\b", text, flags=re.IGNORECASE)
        if select_kw is not None:
            from_kw = re.compile(r"\bFROM\b", re.IGNORECASE).search(
                text, select_kw.end()
            )
            if from_kw is not None:
                # Splice by position rather than str.replace(), which would
                # also rewrite every other occurrence of the column list text.
                text = f"{text[: select_kw.end()]} * {text[from_kw.start() :]}"
        return super().parse(text)
def get_result_from_sqldb(db: SQLDatabase, cmd: str) -> Sequence[Dict[str, Any]]:
    """Execute ``cmd`` on ``db`` and return everything it fetches.

    Args:
        db: Database wrapper; its ``_execute`` method is called with
            ``fetch="all"``.
        cmd: SQL command to run.

    Returns:
        The value returned by ``db._execute``, asserted to be a sequence.
    """
    rows = db._execute(cmd, fetch="all")
    # Narrow the type for callers; fails if the backend returns a non-sequence.
    assert isinstance(rows, Sequence)
    return rows
class VectorSQLDatabaseChain(SQLDatabaseChain):
    """Chain for interacting with a Vector SQL database.

    Example:
        .. code-block:: python

            llm = ...         # any BaseLanguageModel
            embeddings = ...  # any Embeddings implementation
            db = SQLDatabase(...)
            db_chain = VectorSQLDatabaseChain.from_llm(
                llm,
                db,
                sql_cmd_parser=VectorSQLOutputParser.from_embeddings(embeddings),
            )

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include the permissions this chain
        needs. Failure to do so may result in data corruption or loss, since
        this chain may attempt commands like `DROP TABLE` or `INSERT` if
        appropriately prompted. The best way to guard against such negative
        outcomes is to (as appropriate) limit the permissions granted to the
        credentials used with this chain.
        This issue shows an example negative outcome if these steps are not
        taken: https://github.com/langchain-ai/langchain/issues/5923
    """

    sql_cmd_parser: VectorSQLOutputParser
    """Parser that turns LLM-generated Vector SQL into executable SQL."""
    native_format: bool = False
    """If return_direct is True, controls whether to return in Python native
    format. NOTE(review): `_call` below never reads this flag."""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate SQL with the LLM, execute it and optionally summarise it.

        Args:
            inputs: Chain inputs. ``inputs[self.input_key]`` is the question;
                the optional ``"table_names_to_use"`` entry restricts the
                tables whose info is put into the prompt.
            run_manager: Callback manager; a no-op manager is used if omitted.

        Returns:
            A dict with ``self.output_key`` mapped to the parsed SQL (when
            ``return_sql`` is set), the raw query result (when
            ``return_direct`` is set) or the LLM's final answer, plus the
            intermediate steps when ``return_intermediate_steps`` is set.

        Raises:
            Exception: Any error from the LLM, parser or database is re-raised
                with an ``intermediate_steps`` attribute attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, then defaults to None which is all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        intermediate_steps: List = []
        try:
            # Record a snapshot: `llm_inputs` is mutated further down, and
            # storing the dict itself would rewrite this step after the fact.
            intermediate_steps.append(llm_inputs.copy())  # input: sql generation
            llm_out = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            )
            sql_cmd = self.sql_cmd_parser.parse(llm_out)
            if self.return_sql:
                return {self.output_key: sql_cmd}
            if not self.use_query_checker:
                _run_manager.on_text(llm_out, color="green", verbose=self.verbose)
                intermediate_steps.append(
                    llm_out
                )  # output: sql generation (no checker)
                # The raw LLM output is logged; the parsed command is executed.
                intermediate_steps.append({"sql_cmd": llm_out})  # input: sql exec
                result = get_result_from_sqldb(self.database, sql_cmd)
                intermediate_steps.append(str(result))  # output: sql exec
            else:
                query_checker_prompt = self.query_checker_prompt or PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query", "dialect"]
                )
                query_checker_chain = LLMChain(
                    llm=self.llm_chain.llm,
                    prompt=query_checker_prompt,
                    output_parser=self.llm_chain.output_parser,
                )
                query_checker_inputs = {
                    "query": llm_out,
                    "dialect": self.database.dialect,
                }
                checked_llm_out = query_checker_chain.predict(
                    callbacks=_run_manager.get_child(), **query_checker_inputs
                )
                checked_sql_command = self.sql_cmd_parser.parse(checked_llm_out)
                intermediate_steps.append(
                    checked_llm_out
                )  # output: sql generation (checker)
                _run_manager.on_text(
                    checked_llm_out, color="green", verbose=self.verbose
                )
                intermediate_steps.append(
                    {"sql_cmd": checked_llm_out}
                )  # input: sql exec
                result = get_result_from_sqldb(self.database, checked_sql_command)
                intermediate_steps.append(str(result))  # output: sql exec
                llm_out = checked_llm_out
                sql_cmd = checked_sql_command
            _run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
            _run_manager.on_text(str(result), color="yellow", verbose=self.verbose)
            # If return direct, we just set the final result equal to
            # the result of the sql query result (`Sequence[Dict[str, Any]]`),
            # otherwise try to get a human readable final answer (`str`).
            final_result: Union[str, Sequence[Dict[str, Any]]]
            if self.return_direct:
                final_result = result
            else:
                _run_manager.on_text("\nAnswer:", verbose=self.verbose)
                input_text += f"{llm_out}\nSQLResult: {result}\nAnswer:"
                llm_inputs["input"] = input_text
                intermediate_steps.append(llm_inputs.copy())  # input: final answer
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                intermediate_steps.append(final_result)  # output: final answer
                _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        return "vector_sql_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        db: SQLDatabase,
        prompt: Optional[BasePromptTemplate] = None,
        sql_cmd_parser: Optional[VectorSQLOutputParser] = None,
        **kwargs: Any,
    ) -> VectorSQLDatabaseChain:
        """Build the chain from an LLM, a database and a Vector SQL parser.

        Args:
            llm: Language model used for SQL generation and the final answer.
            db: Database to query.
            prompt: Prompt for SQL generation; when omitted, the entry of
                ``SQL_PROMPTS`` for ``db.dialect`` is used, else ``PROMPT``.
            sql_cmd_parser: Parser for the generated Vector SQL. Required
                despite the ``None`` default.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A configured ``VectorSQLDatabaseChain``.

        Raises:
            AssertionError: If ``sql_cmd_parser`` is not provided.
        """
        assert sql_cmd_parser, "`sql_cmd_parser` must be set in VectorSQLDatabaseChain."
        prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(
            llm_chain=llm_chain, database=db, sql_cmd_parser=sql_cmd_parser, **kwargs
        )