# Source code for langchain_community.utilities.spark_sql
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
if TYPE_CHECKING:
from pyspark.sql import DataFrame, Row, SparkSession
class SparkSQL:
    """Utility class for interacting with Spark SQL.

    Wraps a ``SparkSession`` and exposes helpers to list tables, describe
    them (DDL plus a few sample rows) and run SQL commands, returning the
    results as strings.
    """

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
                If not provided, one is obtained via
                ``SparkSession.builder.getOrCreate()``.
            catalog: The catalog to use.
                If not provided, the current catalog is left unchanged.
            schema: The schema (database) to use.
                If not provided, the current database is left unchanged.
            ignore_tables: A list of tables to ignore.
                If not provided, no table is ignored.
            include_tables: A list of tables to include.
                If not provided, all tables are used.
            sample_rows_in_table_info: The number of sample rows to include
                in the table info. Defaults to 3.

        Raises:
            ImportError: If pyspark is not installed.
            ValueError: If ``include_tables`` or ``ignore_tables`` names a
                table that ``SHOW TABLES`` does not report.
            TypeError: If ``sample_rows_in_table_info`` is not an integer.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        if spark_session is not None:
            self._spark = spark_session
        else:
            self._spark = SparkSession.builder.getOrCreate()

        # Switch catalog/database before listing tables so that the table
        # names below are resolved in the requested namespace.
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)
        self._all_tables = set(self._get_all_table_names())

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a SparkSQL backed by a remote Spark session (Spark Connect).

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: The remote URI passed to
                ``SparkSession.builder.remote``.
            engine_args: Accepted but not used by this method.
            **kwargs: Forwarded to the ``SparkSQL`` constructor.

        Raises:
            ImportError: If pyspark is not installed.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get the names of the usable tables, sorted alphabetically.

        If ``include_tables`` was given, exactly those tables are returned;
        otherwise all tables minus ``ignore_tables``.
        """
        # Sorting the result can help the LLM understand it, and keeps the
        # output deterministic regardless of set iteration order.
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES`` as a list."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for *table*.

        Everything from the first ``USING`` onwards is dropped, and the
        result is terminated with ``;``.
        """
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # No USING clause: keep the statement whole. Slicing with -1
            # would silently drop its last character.
            return statement + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about the specified tables.

        Follows best practices as specified in Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498).

        If ``sample_rows_in_table_info`` is non-zero, that number of sample
        rows is appended to each table description inside a ``/* ... */``
        block. As demonstrated in the paper, this can improve performance.

        Args:
            table_names: Tables to describe. If None, all usable tables
                are described.

        Returns:
            The per-table descriptions joined by blank lines.

        Raises:
            ValueError: If any of ``table_names`` is not a usable table.
        """
        selected_tables = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(selected_tables)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            selected_tables = table_names

        tables = []
        for table_name in selected_tables:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Return a tab-separated header plus sample rows for *table*."""
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            # Best effort: if the rows cannot be fetched, still return the
            # header and column names with no sample rows.
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Convert a row into a tuple with every value passed through ``str``."""
        # Iterate the row positionally (pyspark Row is a tuple subclass)
        # instead of going through asDict(), which would collapse columns
        # that share a name, e.g. in join results.
        return tuple(str(value) for value in row)

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect *df* and return its rows as a list of string tuples."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the results as a string.

        Args:
            command: The SQL text passed to ``SparkSession.sql``.
            fetch: If ``"one"``, the result is limited to a single row;
                any other value returns all rows.

        Returns:
            ``str()`` of the list of row tuples.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about the specified tables without raising ValueError.

        Behaves like ``get_table_info``, except that a ``ValueError`` is
        caught and returned as the string ``"Error: <message>"``. Other
        exception types still propagate.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, the string ``"[]"`` is returned.
        If the statement throws an error, ``"Error: <message>"`` is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"