# Source code for langchain_community.utilities.sql_database

"""SQLAlchemy对数据库的封装。"""
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union

import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import (
    MetaData,
    Table,
    create_engine,
    inspect,
    select,
    text,
)
from sqlalchemy.engine import Engine, Result
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql.expression import Executable
from sqlalchemy.types import NullType


def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    return (
        f'Name: {index["name"]}, Unique: {index["unique"]},'
        f' Columns: {str(index["column_names"])}'
    )


def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string to at most ``length`` characters on a word boundary.

    Args:
        content: Value to truncate. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the result. A value ``<= 0`` disables
            truncation and returns ``content`` unchanged.
        suffix: Marker appended to a truncated string.

    Returns:
        ``content`` itself when no truncation applies; otherwise the leading
        part of ``content`` cut back to the last space, followed by ``suffix``.
        When ``length`` leaves no room for the whole suffix, the suffix itself
        is cut to ``length`` characters so the result never exceeds ``length``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    budget = length - len(suffix)
    if budget <= 0:
        # A non-positive slice bound would count from the end of the string
        # and keep almost all of it; fall back to (part of) the suffix instead.
        return suffix[:length]

    return content[:budget].rsplit(" ", 1)[0] + suffix
class SQLDatabase:
    """SQLAlchemy wrapper around a database."""

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Schema passed to the inspector and to metadata reflection;
                also applied per connection in ``_execute`` for some dialects.
            metadata: Existing ``MetaData`` to reflect into. A new one is
                created when omitted.
            ignore_tables: Table names to exclude. Cannot be combined with
                ``include_tables``.
            include_tables: Table names to restrict to. Cannot be combined
                with ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description by ``get_table_info``.
            indexes_in_table_info: Whether ``get_table_info`` lists indexes.
            custom_table_info: Mapping of table name to a hand-written table
                description used instead of the generated one. Entries for
                tables that are not in the database are dropped.
            view_support: Whether views are treated as tables.
            max_string_length: Maximum length of each value in ``run`` output.
            lazy_table_reflection: If true, skip reflection here; tables are
                reflected on demand by ``get_table_info``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table that does not exist.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int`` or a
                non-empty ``custom_table_info`` is not a ``dict``.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views are added to the set of known tables only when view_support
        # is enabled.
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep the entries whose table is also present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length
        self._view_support = view_support

        self._metadata = metadata or MetaData()
        if not lazy_table_reflection:
            # Reflect the usable tables up front (views too when enabled).
            self._metadata.reflect(
                views=view_support,
                bind=self._engine,
                only=list(self._usable_tables),
                schema=self._schema,
            )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Construct an instance from a database URI.

        Args:
            database_uri: URI passed to ``sqlalchemy.create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance wrapping the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """Create a SQLDatabase instance from a Databricks connection.

        This method requires the 'databricks-sql-connector' package. If it is
        not installed, add it with `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                the 'https://' part. If not provided, it is looked up from the
                'DATABRICKS_HOST' environment variable, falling back to the
                current workspace hostname when running inside a Databricks
                notebook. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token
                for the SQL warehouse or cluster. If not provided, it is looked
                up from 'DATABRICKS_TOKEN', falling back to the token exposed
                by the notebook context when running inside a Databricks
                notebook. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in Databricks SQL.
                Cannot be used together with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks
                Runtime. Cannot be used together with 'warehouse_id'. When
                both are None and a notebook context is available, the
                context's cluster ID is used. Defaults to None.
            engine_args (Optional[dict]): Arguments used when creating the
                engine. Defaults to None.
            **kwargs (Any): Additional keyword arguments for `from_uri`.

        Returns:
            SQLDatabase: An instance configured with the Databricks
                connection details.

        Raises:
            ImportError: If 'databricks-sql-connector' is not installed.
            ValueError: If both 'warehouse_id' and 'cluster_id' are provided,
                or if neither is provided and no notebook context is
                available.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from err

        # The notebook context only exists inside a Databricks runtime; its
        # absence is expected elsewhere and is not an error.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """Create a SQLDatabase instance from a CnosDB connection.

        This method requires the 'cnos-connector' package. If it is not
        installed, add it with `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the
                CnosDB service, excluding "http://" or "https://". Defaults
                to "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service.
                Defaults to "root".
            password (str): The password of that user. Defaults to "".
            tenant (str): The tenant name used to connect to the CnosDB
                service. Defaults to "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance configured with the CnosDB connection
                details.

        Raises:
            ImportError: If 'cnos-connector' is not installed.
        """
        # Keep the try body to the optional import only, so an ImportError
        # raised while building the engine is not misreported as a missing
        # cnos-connector package.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as err:
            raise ImportError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from err

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return the name of the engine's dialect."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get the sorted names of the tables available for use.

        Returns ``include_tables`` when it was given, otherwise every known
        table minus ``ignore_tables``.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="0.3.0")
    def get_table_names(self) -> Iterable[str]:
        """Get the names of the tables available for use (deprecated alias)."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all usable tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about the specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is set, that number of sample rows is
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe. Defaults to all usable tables.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If ``table_names`` contains a table that is not usable.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        requested = set(all_table_names)

        # Reflect any requested table that the metadata does not know yet
        # (this is what makes lazy reflection work).
        already_reflected = {tbl.name for tbl in self._metadata.sorted_tables}
        to_reflect = requested - already_reflected
        if to_reflect:
            self._metadata.reflect(
                views=self._view_support,
                bind=self._engine,
                only=list(to_reflect),
                schema=self._schema,
            )

        is_sqlite = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in requested
            and not (is_sqlite and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Drop columns whose type reflected as NullType (e.g. JSON).
            # Iterate over a snapshot because the collection is mutated.
            for column in list(table.columns):
                if type(column.type) is NullType:
                    table._columns.remove(column)

            # Start with the CREATE TABLE statement.
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Indexes and sample rows go inside one SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        return "\n\n".join(tables)

    def _get_table_indexes(self, table: Table) -> str:
        """Return a formatted listing of the indexes defined on ``table``."""
        # Pass the table's schema so tables outside the default schema are
        # looked up in the right place.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        At most ``sample_rows_in_table_info`` rows are selected and every
        value is cut to 100 characters.
        """
        command = select(table).limit(self._sample_rows_in_table_info)

        columns_str = "\t".join([col.name for col in table.columns])

        try:
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # Shorten the values in the sample rows.
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # In some dialects a 'ProgrammingError' is raised when the table has
        # no rows.
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """Execute a SQL command through the underlying engine.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"`` for every row, ``"one"`` for the first row only,
                or ``"cursor"`` to get the raw result object.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to every statement executed.

        Returns:
            A list of row dictionaries, the result object when ``fetch`` is
            ``"cursor"``, or an empty list if the statement returns no rows.

        Raises:
            TypeError: If ``command`` is neither a string nor an executable.
            ValueError: If rows are returned and ``fetch`` is not recognised.
        """
        parameters = parameters or {}
        execution_options = execution_options or {}
        with self._engine.begin() as connection:  # type: ignore[name-defined]
            if self._schema is not None:
                # Point the connection at the configured schema; the
                # statement to do so depends on the dialect.
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(
                        "SET @@dataset_id=?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql(
                        "USE ?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    # NOTE: the schema name is interpolated into the SQL text,
                    # so it must come from trusted configuration.
                    connection.exec_driver_sql(
                        f"SET search_path TO {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "oracle":
                    # NOTE: the schema name is interpolated into the SQL text,
                    # so it must come from trusted configuration.
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "sqlany":
                    # Sybase SQL Anywhere is handled like mssql: nothing to set.
                    pass
                elif self.dialect == "postgresql":
                    connection.exec_driver_sql(
                        "SET search_path TO %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )

            if isinstance(command, str):
                command = text(command)
            elif isinstance(command, Executable):
                pass
            else:
                raise TypeError(f"Query expression has unknown type: {type(command)}")
            cursor = connection.execute(
                command,
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    # NOTE(review): returned from inside the `with` block, so
                    # the transaction context exits on return; confirm the
                    # result is still usable with the target driver.
                    return cursor
                else:
                    raise ValueError(
                        "Fetch parameter must be either 'one', 'all', or 'cursor'"
                    )
                return result
        return []

    def run(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        When ``fetch`` is ``"cursor"`` the raw result of ``_execute`` is
        returned instead.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"``, ``"one"`` or ``"cursor"``.
            include_columns: If true, rows are rendered as dictionaries keyed
                by column name; otherwise as tuples of values.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to the execution.
        """
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        # Shorten long string values so the rendered output stays bounded.
        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about the specified tables without raising.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is set, that number of sample rows is
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` from ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        If the statement raises a ``SQLAlchemyError``, the error message is
        returned as an ``"Error: ..."`` string.
        """
        try:
            return self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
        except SQLAlchemyError as e:
            # Format the error message.
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return database context that may be needed in an agent prompt.

        Returns:
            A dict with ``"table_info"`` (the table descriptions, or an error
            string) and ``"table_names"`` (usable names joined by ``", "``).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}