"""SQLAlchemy对数据库的封装。"""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
import sqlalchemy
from langchain_core._api import deprecated
from langchain_core.utils import get_from_env
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import Engine, Result
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql.expression import Executable
from sqlalchemy.types import NullType
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string to about ``length`` characters, preferring a word boundary.

    The text is cut so that, together with ``suffix``, it fits in ``length``
    characters. It is then trimmed back to its last space (if it contains one),
    and ``suffix`` is appended. If ``suffix`` alone is longer than ``length``,
    the result is just ``suffix``.

    Args:
        content: Value to truncate. Non-string values are returned unchanged.
        length: Target maximum length. Values ``<= 0`` disable truncation.
        suffix: Marker appended to truncated text.

    Returns:
        ``content`` unchanged if it is not a string, ``length <= 0``, or it is
        already at most ``length`` characters; otherwise the shortened text
        followed by ``suffix``.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    # Clamp at 0: a negative slice bound (suffix longer than length) would
    # count from the end and keep almost the whole string.
    keep = max(length - len(suffix), 0)
    return content[:keep].rsplit(" ", 1)[0] + suffix
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Inspects and reflects the database schema through an ``Engine`` and
    exposes helpers that render table descriptions (``CREATE TABLE``
    statements plus optional indexes and sample rows) and execute SQL.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Schema to inspect; for some dialects it is also selected
                on the connection before each command (see ``_execute``).
            metadata: ``MetaData`` to reflect into; a new one is created when
                omitted.
            ignore_tables: Tables to hide. Mutually exclusive with
                ``include_tables``.
            include_tables: The only tables to expose. Mutually exclusive with
                ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description; ``0`` disables sampling.
            indexes_in_table_info: Whether to append index information to
                each table description.
            custom_table_info: Mapping of table name to a description string
                used instead of the generated one. Entries for tables that do
                not exist in the database are dropped.
            view_support: Whether views are listed and reflected as tables.
            max_string_length: Maximum length of each value in ``run`` output.
            lazy_table_reflection: If True, skip reflection here; tables are
                reflected on demand by ``get_table_info``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or either names a table not found in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an int, or
                ``custom_table_info`` is given but is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views join the table list only when view_support is True.
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep entries for tables that are present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length
        self._view_support = view_support

        self._metadata = metadata or MetaData()
        if not lazy_table_reflection:
            # Views are reflected too when view_support is True.
            self._metadata.reflect(
                views=view_support,
                bind=self._engine,
                only=list(self._usable_tables),
                schema=self._schema,
            )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Construct a SQLDatabase from a SQLAlchemy database URI.

        Args:
            database_uri: URI passed to ``create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the ``SQLDatabase`` constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """Create a SQLDatabase instance from a Databricks connection.

        Requires the 'databricks-sql-connector' package, which can be
        installed with `pip install databricks-sql-connector`.

        Args:
            catalog: Catalog name in the Databricks database.
            schema: Schema name within the catalog.
            host: Databricks workspace hostname, without the 'https://' part.
                If not provided, it is read from the 'DATABRICKS_HOST'
                environment variable; failing that, when running in a
                Databricks notebook, the current workspace hostname is used.
            api_token: Databricks personal access token for the SQL warehouse
                or cluster. If not provided, it is read from
                'DATABRICKS_TOKEN'; failing that, when running in a
                Databricks notebook, the notebook context's API token is used.
            warehouse_id: Databricks SQL warehouse ID. If provided, the
                connection uses this warehouse. Cannot be combined with
                'cluster_id'.
            cluster_id: Databricks Runtime cluster ID. If provided, the
                connection uses this cluster. Cannot be combined with
                'warehouse_id'. When running in a Databricks notebook and both
                IDs are None, the notebook's attached cluster is used.
            engine_args: Arguments used when creating the engine.
            **kwargs: Additional keyword arguments for ``from_uri``.

        Returns:
            A SQLDatabase configured with the given Databricks connection.

        Raises:
            ImportError: If 'databricks-sql-connector' is not installed.
            ValueError: If both 'warehouse_id' and 'cluster_id' are provided,
                or neither is provided outside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from e

        # Inside a Databricks notebook the REPL context supplies defaults for
        # host, token and cluster; elsewhere the import fails and we go on.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' and 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """Create a SQLDatabase instance from a CnosDB connection.

        Requires the 'cnos-connector' package, which can be installed with
        `pip install cnos-connector`.

        Args:
            url: Host and port of the CnosDB HTTP endpoint, without "http://"
                or "https://". Defaults to "127.0.0.1:8902".
            user: User name for the CnosDB service. Defaults to "root".
            password: Password of that user. Defaults to "".
            tenant: Tenant name on the CnosDB service. Defaults to "cnosdb".
            database: Database name within the tenant. Defaults to "public".

        Returns:
            A SQLDatabase configured with the given CnosDB connection.

        Raises:
            ImportError: If 'cnos-connector' is not installed.
        """
        # Only the connector import is guarded: an ImportError raised later
        # (e.g. while creating the engine) must not be misreported as a
        # missing cnos-connector package.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as e:
            raise ImportError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from e

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Name of the engine's SQL dialect (e.g. ``"sqlite"``)."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Sorted names of the tables this wrapper exposes.

        This is ``include_tables`` if it was given, otherwise all tables
        minus ``ignore_tables``.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="0.3.0")
    def get_table_names(self) -> Iterable[str]:
        """Deprecated alias of ``get_usable_table_names``."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all usable tables (see ``get_table_info``)."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about the specified tables.

        Follows best practices specified in Rajkumar et al. 2022
        (https://arxiv.org/abs/2204.00498).

        Each table is described by its ``CREATE TABLE`` statement, optionally
        followed by a ``/* ... */`` section with its indexes and/or
        ``sample_rows_in_table_info`` sample rows; the paper shows sample rows
        can improve performance. Tables with an entry in ``custom_table_info``
        use that text instead. Descriptions are sorted and joined by blank
        lines.

        Args:
            table_names: Tables to describe; defaults to all usable tables.

        Raises:
            ValueError: If any name in ``table_names`` is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names
        wanted_tables = set(all_table_names)

        # Reflect requested tables that are not in the metadata yet (e.g. when
        # lazy_table_reflection was used).
        metadata_table_names = {tbl.name for tbl in self._metadata.sorted_tables}
        to_reflect = wanted_tables - metadata_table_names
        if to_reflect:
            self._metadata.reflect(
                views=self._view_support,
                bind=self._engine,
                only=list(to_reflect),
                schema=self._schema,
            )

        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted_tables
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Drop columns whose reflected type is NullType (types SQLAlchemy
            # could not map; originally noted for JSON columns) — presumably
            # because no DDL can be rendered for them. This mutates the shared
            # metadata table; iterate over a snapshot while removing.
            for column in list(table.columns):
                if type(column.type) is NullType:
                    table._columns.remove(column)

            create_table = str(CreateTable(table).compile(self._engine))
            table_info = create_table.rstrip()
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        return "\n\n".join(tables)

    def _get_table_indexes(self, table: Table) -> str:
        """Describe the indexes of ``table``, one per line, under a header."""
        # Look the table up in its own schema; with schema=None the inspector
        # would search the default schema even when ``schema`` was configured.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Render up to ``sample_rows_in_table_info`` rows of ``table``.

        Output is a header line, a tab-separated line of column names, then
        one tab-separated line per row with each value cut to 100 characters.
        """
        command = select(table).limit(self._sample_rows_in_table_info)
        columns_str = "\t".join([col.name for col in table.columns])
        try:
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # Shorten each value to at most 100 characters.
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
        # In some dialects a 'ProgrammingError' is raised when the table has
        # no rows.
        except ProgrammingError:
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """Execute a SQL command through the underlying engine.

        The command runs inside ``engine.begin()``. When a schema is
        configured, a dialect-specific statement selects it first for
        snowflake, bigquery, trino, duckdb, oracle and postgresql (mssql and
        sqlany are deliberately skipped; other dialects are not switched).

        Args:
            command: Raw SQL string (wrapped in ``text``) or an ``Executable``.
            fetch: ``"all"`` for every row, ``"one"`` for the first row only,
                ``"cursor"`` for the raw ``Result``.
            parameters: Bind parameters for the command.
            execution_options: Execution options used for every statement.

        Returns:
            A list of row dicts, the raw ``Result`` when ``fetch="cursor"``,
            or an empty list if the statement returns no rows.

        Raises:
            TypeError: If ``command`` is neither a string nor an ``Executable``.
            ValueError: If ``fetch`` is unsupported and the statement returns
                rows.
        """
        parameters = parameters or {}
        execution_options = execution_options or {}
        with self._engine.begin() as connection:  # type: Connection # type: ignore[name-defined]
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(
                        "SET @@dataset_id=?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql(
                        "USE ?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    connection.exec_driver_sql(
                        f"SET search_path TO {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "oracle":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "sqlany":
                    # Sybase SQL Anywhere must not fall through to the
                    # default; it is handled like mssql.
                    pass
                elif self.dialect == "postgresql":
                    connection.exec_driver_sql(
                        "SET search_path TO %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )

            if isinstance(command, str):
                command = text(command)
            elif isinstance(command, Executable):
                pass
            else:
                raise TypeError(f"Query expression has unknown type: {type(command)}")
            cursor = connection.execute(
                command,
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    # NOTE(review): the connection context closes as this
                    # return leaves the `with`; confirm callers can still
                    # fetch from the returned cursor.
                    return cursor
                else:
                    raise ValueError(
                        "Fetch parameter must be either 'one', 'all', or 'cursor'"
                    )
                return result
        return []

    def run(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, the ``str`` of the results is returned;
        if it returns no rows, an empty string is returned. With
        ``fetch="cursor"`` the raw ``Result`` is returned unchanged.

        Each value is shortened with ``truncate_word`` to about
        ``max_string_length`` characters. Rows are rendered as tuples, or as
        dicts keyed by column name when ``include_columns`` is True.
        """
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Like ``get_table_info``, but return errors as text instead of raising.

        Follows best practices specified in Rajkumar et al. 2022
        (https://arxiv.org/abs/2204.00498). If ``sample_rows_in_table_info``
        is set, that many sample rows are appended to each table description.

        Returns:
            The table information, or ``"Error: <message>"`` if
            ``get_table_info`` raised a ``ValueError``.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Like ``run``, but return SQLAlchemy errors as text instead of raising.

        If the statement returns rows, the ``str`` of the results is returned;
        if it returns no rows, an empty string is returned. If execution
        raises a ``SQLAlchemyError``, ``"Error: <message>"`` is returned.
        """
        try:
            return self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
        except SQLAlchemyError as e:
            # Format the error message.
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return database context that an agent prompt may need.

        Returns:
            A dict with ``"table_info"`` (from ``get_table_info_no_throw``)
            and ``"table_names"`` (usable table names joined by ``", "``).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}