"""Apache Cassandra database wrapper."""
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

if TYPE_CHECKING:
    from cassandra.cluster import ResultSet, Session
# Keyspaces that are left out of keyspace listings unless the caller asks for
# them by name.
IGNORED_KEYSPACES = [
    "system",
    "system_auth",
    "system_distributed",
    "system_schema",
    "system_traces",
    "system_views",
    "datastax_sla",
    "data_endpoint_auth",
]
[docs]class CassandraDatabase:
"""Apache Cassandra®数据库包装器。"""
[docs] def __init__(
self,
session: Optional[Session] = None,
exclude_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
cassio_init_kwargs: Optional[Dict[str, Any]] = None,
):
_session = self._resolve_session(session, cassio_init_kwargs)
if not _session:
raise ValueError("Session not provided and cannot be resolved")
self._session = _session
self._exclude_keyspaces = IGNORED_KEYSPACES
self._exclude_tables = exclude_tables or []
self._include_tables = include_tables or []
[docs] def run(
self,
query: str,
fetch: str = "all",
**kwargs: Any,
) -> Union[list, Dict[str, Any], ResultSet]:
"""执行一个CQL查询并返回结果。"""
if fetch == "all":
return self.fetch_all(query, **kwargs)
elif fetch == "one":
return self.fetch_one(query, **kwargs)
elif fetch == "cursor":
return self._fetch(query, **kwargs)
else:
raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")
def _fetch(self, query: str, **kwargs: Any) -> ResultSet:
clean_query = self._validate_cql(query, "SELECT")
return self._session.execute(clean_query, **kwargs)
[docs] def fetch_all(self, query: str, **kwargs: Any) -> list:
return list(self._fetch(query, **kwargs))
[docs] def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]:
result = self._fetch(query, **kwargs)
return result.one()._asdict() if result else {}
[docs] def get_keyspace_tables(self, keyspace: str) -> List[Table]:
"""获取指定keyspace的Table对象。"""
schema = self._resolve_schema([keyspace])
if keyspace in schema:
return schema[keyspace]
else:
return []
# This is a more basic string building function that doesn't use a query builder
# or prepared statements
# TODO: Refactor to use prepared statements
[docs] def get_table_data(
self, keyspace: str, table: str, predicate: str, limit: int
) -> str:
"""从指定keyspace中的指定表中获取数据。"""
query = f"SELECT * FROM {keyspace}.{table}"
if predicate:
query += f" WHERE {predicate}"
if limit:
query += f" LIMIT {limit}"
query += ";"
result = self.fetch_all(query)
data = "\n".join(str(row) for row in result)
return data
[docs] def get_context(self) -> Dict[str, Any]:
"""返回可能在代理提示中需要的数据库上下文。"""
keyspaces = self._fetch_keyspaces()
return {"keyspaces": ", ".join(keyspaces)}
def _validate_cql(self, cql: str, type: str = "SELECT") -> str:
"""验证CQL查询字符串的基本格式和安全性检查。
确保`cql`以指定类型(例如SELECT)开头,并且不包含可能指示CQL注入漏洞的内容。
参数:
cql:要验证的CQL查询字符串。
type:查询的预期起始关键字,用于验证查询是否以正确的操作类型(例如"SELECT","UPDATE")开头。默认为"SELECT"。
返回:
修剪和经过验证的CQL查询字符串,不带分号结尾。
引发:
ValueError:如果`type`的值不受支持
DatabaseError:如果认为`cql`不安全
"""
SUPPORTED_TYPES = ["SELECT"]
if type and type.upper() not in SUPPORTED_TYPES:
raise ValueError(
f"""Unsupported CQL type: {type}. Supported types:
{SUPPORTED_TYPES}"""
)
# Basic sanity checks
cql_trimmed = cql.strip()
if not cql_trimmed.upper().startswith(type.upper()):
raise DatabaseError(f"CQL must start with {type.upper()}.")
# Allow a trailing semicolon, but remove (it is optional with the Python driver)
cql_trimmed = cql_trimmed.rstrip(";")
# Consider content within matching quotes to be "safe"
# Remove single-quoted strings
cql_sanitized = re.sub(r"'.*?'", "", cql_trimmed)
# Remove double-quoted strings
cql_sanitized = re.sub(r'".*?"', "", cql_sanitized)
# Find unsafe content in the remaining CQL
if ";" in cql_sanitized:
raise DatabaseError(
"""Potentially unsafe CQL, as it contains a ; at a
place other than the end or within quotation marks."""
)
# The trimmed query, before modifications
return cql_trimmed
def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]:
"""从Cassandra数据库中获取一个键空间名称列表。该列表可以通过提供的键空间名称列表或排除预定义键空间来进行过滤。
参数:
keyspaces: 要特别包含的键空间名称列表。
如果提供且不为空,则该方法仅返回此列表中存在的键空间。
如果未提供或为空,则该方法返回除了在_exclude_keyspaces属性中指定的键空间之外的所有键空间。
返回:
根据过滤条件的键空间名称列表。
"""
all_keyspaces = self.fetch_all(
"SELECT keyspace_name FROM system_schema.keyspaces"
)
# Filtering keyspaces based on 'keyspace_list' and '_exclude_keyspaces'
filtered_keyspaces = []
for ks in all_keyspaces:
if not isinstance(ks, Dict):
continue # Skip if the row is not a dictionary.
keyspace_name = ks["keyspace_name"]
if keyspaces and keyspace_name in keyspaces:
filtered_keyspaces.append(keyspace_name)
elif not keyspaces and keyspace_name not in self._exclude_keyspaces:
filtered_keyspaces.append(keyspace_name)
return filtered_keyspaces
def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str:
# Construct IN clause for CQL query
keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspaces])
return f"""{query} WHERE keyspace_name IN ({keyspace_in_clause})"""
def _fetch_tables_data(self, keyspaces: List[str]) -> list:
"""获取由一组keyspaces筛选的表模式数据。
该方法允许有效地获取多个keyspaces的模式信息,从而使应用程序能够以编程方式分析或记录数据库模式。
参数:
keyspaces:要获取表模式数据的keyspace名称列表。
返回:
表详细信息的字典(keyspace名称、表名称和注释)。
"""
tables_query = self._format_keyspace_query(
"SELECT keyspace_name, table_name, comment FROM system_schema.tables",
keyspaces,
)
return self.fetch_all(tables_query)
def _fetch_columns_data(self, keyspaces: List[str]) -> list:
"""获取按keyspaces列表过滤的列模式数据。
该方法允许有效地在单个操作中获取多个keyspaces的模式信息,使应用程序能够以编程方式分析或记录数据库模式。
参数:
keyspaces:要从中获取表模式数据的keyspace名称列表。
返回:
列详细信息的字典(keyspace名称,表名称,列名称,类型,种类和位置)。
"""
tables_query = self._format_keyspace_query(
"""
SELECT keyspace_name, table_name, column_name, type, kind,
clustering_order, position
FROM system_schema.columns
""",
keyspaces,
)
return self.fetch_all(tables_query)
def _fetch_indexes_data(self, keyspaces: List[str]) -> list:
"""获取按keyspaces列表过滤的索引模式数据。
该方法允许在单个操作中高效地获取多个keyspaces的模式信息,使应用程序能够以编程方式分析或记录数据库模式。
参数:
keyspaces:要获取表模式数据的keyspace名称列表。
返回:
索引详细信息的字典(keyspace名称、表名称、索引名称、类型和选项)。
"""
tables_query = self._format_keyspace_query(
"""
SELECT keyspace_name, table_name, index_name,
kind, options
FROM system_schema.indexes
""",
keyspaces,
)
return self.fetch_all(tables_query)
def _resolve_schema(
self, keyspaces: Optional[List[str]] = None
) -> Dict[str, List[Table]]:
"""高效地获取和整理Cassandra表模式信息,如注释、列和索引,将其组织成一个字典,将keyspace名称映射到Table对象的列表。
参数:
keyspaces: 一个可选的keyspace名称列表,用于获取表模式数据。
返回:
一个字典,以keyspace名称为键,以Table对象的列表为值,其中每个Table对象都填充有适用于其keyspace和表名称的模式详细信息。
"""
if not keyspaces:
keyspaces = self._fetch_keyspaces()
tables_data = self._fetch_tables_data(keyspaces)
columns_data = self._fetch_columns_data(keyspaces)
indexes_data = self._fetch_indexes_data(keyspaces)
keyspace_dict: dict = {}
for table_data in tables_data:
keyspace = table_data.keyspace_name
table_name = table_data.table_name
comment = table_data.comment
if self._include_tables and table_name not in self._include_tables:
continue
if self._exclude_tables and table_name in self._exclude_tables:
continue
# Filter columns and indexes for this table
table_columns = [
(c.column_name, c.type)
for c in columns_data
if c.keyspace_name == keyspace and c.table_name == table_name
]
partition_keys = [
c.column_name
for c in columns_data
if c.kind == "partition_key"
and c.keyspace_name == keyspace
and c.table_name == table_name
]
clustering_keys = [
(c.column_name, c.clustering_order)
for c in columns_data
if c.kind == "clustering"
and c.keyspace_name == keyspace
and c.table_name == table_name
]
table_indexes = [
(c.index_name, c.kind, c.options)
for c in indexes_data
if c.keyspace_name == keyspace and c.table_name == table_name
]
table_obj = Table(
keyspace=keyspace,
table_name=table_name,
comment=comment,
columns=table_columns,
partition=partition_keys,
clustering=clustering_keys,
indexes=table_indexes,
)
if keyspace not in keyspace_dict:
keyspace_dict[keyspace] = []
keyspace_dict[keyspace].append(table_obj)
return keyspace_dict
@staticmethod
def _resolve_session(
session: Optional[Session] = None,
cassio_init_kwargs: Optional[Dict[str, Any]] = None,
) -> Optional[Session]:
"""尝试解析并返回一个用于数据库操作的Session对象。
此函数遵循特定的优先顺序来确定要使用的适当会话:
1. 如果提供了`session`参数,则使用该参数,
2. 使用现有的`cassio`会话,
3. 使用从`cassio_init_kwargs`派生的新的`cassio`会话,
4. 返回`None`
参数:
session: 可选的直接使用的会话。
cassio_init_kwargs: 传递给`cassio`的可选关键字参数字典。
返回:
如果成功,返回解析后的会话对象,如果无法解析会话,则返回`None`。
Raises:
ValueError: 如果提供了`cassio_init_kwargs`但不是关键字参数字典。
"""
# Prefer given session
if session:
return session
# If a session is not provided, create one using cassio if available
# dynamically import cassio to avoid circular imports
try:
import cassio.config
except ImportError:
raise ValueError(
"cassio package not found, please install with" " `pip install cassio`"
)
# Use pre-existing session on cassio
s = cassio.config.resolve_session()
if s:
return s
# Try to init and return cassio session
if cassio_init_kwargs:
if isinstance(cassio_init_kwargs, dict):
cassio.init(**cassio_init_kwargs)
s = cassio.config.check_resolve_session()
return s
else:
raise ValueError("cassio_init_kwargs must be a keyword dictionary")
# return None if we're not able to resolve
return None
class DatabaseError(Exception):
    """Exception raised for errors in the database schema.

    Attributes:
        message -- explanation of the error
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
class Table(BaseModel):
    """Immutable description of a Cassandra table's schema."""

    keyspace: str
    """The keyspace in which the table exists."""

    table_name: str
    """The name of the table."""

    comment: Optional[str] = None
    """The comment associated with the table."""

    columns: List[Tuple[str, str]] = Field(default_factory=list)
    """``(column_name, type)`` pairs; must not be empty."""

    partition: List[str] = Field(default_factory=list)
    """Partition key column names; must not be empty."""

    clustering: List[Tuple[str, str]] = Field(default_factory=list)
    """``(column_name, clustering_order)`` pairs."""

    indexes: List[Tuple[str, str, str]] = Field(default_factory=list)
    """``(index_name, kind, options)`` triples."""

    class Config:
        frozen = True

    @root_validator()
    def check_required_fields(cls, class_values: dict) -> dict:
        """Reject tables without columns or without a partition key."""
        # ``.get`` because a field that already failed its own validation is
        # absent from ``class_values``; indexing would raise KeyError instead
        # of reporting a validation error.
        if not class_values.get("columns"):
            raise ValueError("non-empty column list for must be provided")
        if not class_values.get("partition"):
            raise ValueError("non-empty partition list must be provided")
        return class_values

    @classmethod
    def from_database(
        cls, keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Table:
        """Build a Table by querying ``system_schema`` through ``db``."""
        columns, partition, clustering = cls._resolve_columns(keyspace, table_name, db)
        return cls(
            keyspace=keyspace,
            table_name=table_name,
            comment=cls._resolve_comment(keyspace, table_name, db),
            columns=columns,
            partition=partition,
            clustering=clustering,
            indexes=cls._resolve_indexes(keyspace, table_name, db),
        )

    def as_markdown(
        self, include_keyspace: bool = True, header_level: Optional[int] = None
    ) -> str:
        """Render the table schema as Markdown.

        Args:
            include_keyspace: If True, include the keyspace in the output.
            header_level: Markdown header level for the table name line. If
                None, the table name is written without a header prefix.

        Returns:
            A Markdown string with the table name, keyspace (optional),
            comment (if any), columns, partition keys, clustering keys (with
            their order unless it is "none") and indexes (if any). Every line
            is newline-terminated.
        """
        heading = f"Table Name: {self.table_name}"
        if header_level is not None:
            heading = f"{'#' * header_level} {heading}"

        lines: List[str] = [heading]
        if include_keyspace:
            lines.append(f"- Keyspace: {self.keyspace}")
        if self.comment:
            lines.append(f"- Comment: {self.comment}")

        lines.append("- Columns")
        lines.extend(
            f" - {column} ({column_type})" for column, column_type in self.columns
        )

        lines.append(f"- Partition Keys: ({', '.join(self.partition)})")

        cluster_list = [
            column
            if clustering_order.lower() == "none"
            else f"{column} {clustering_order}"
            for column, clustering_order in self.clustering
        ]
        # Always finish this line, even with no clustering keys, so that the
        # following "- Indexes" entry starts on its own line.
        lines.append(f"- Clustering Keys: ({', '.join(cluster_list)})")

        if self.indexes:
            lines.append("- Indexes")
            lines.extend(
                f" - {name} : kind={kind}, options={options}"
                for name, kind, options in self.indexes
            )

        return "\n".join(lines) + "\n"

    @staticmethod
    def _cql_literal(value: str) -> str:
        """Return ``value`` as a single-quoted CQL string literal."""
        # Doubling single quotes keeps the value inside the literal.
        return "'" + value.replace("'", "''") + "'"

    @staticmethod
    def _row_to_dict(row: Any) -> Optional[Dict[str, Any]]:
        """Return ``row`` as a dict, or None if it cannot be converted.

        Dict rows are used as they are; rows exposing ``_asdict()`` (such as
        named tuples) are converted through it.
        """
        if isinstance(row, dict):
            return row
        as_dict = getattr(row, "_asdict", None)
        if callable(as_dict):
            return as_dict()
        return None

    @staticmethod
    def _resolve_comment(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Optional[str]:
        """Return the table's comment, or None when it has none."""
        result = db.run(
            "SELECT comment "
            "FROM system_schema.tables "
            f"WHERE keyspace_name = {Table._cql_literal(keyspace)} "
            f"AND table_name = {Table._cql_literal(table_name)};",
            fetch="one",
        )

        if not isinstance(result, dict):
            raise ValueError(
                f"Unexpected result type from db.run:\n{type(result).__name__}"
            )
        # Empty comments are reported as None.
        return result.get("comment") or None

    @staticmethod
    def _resolve_columns(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> Tuple[List[Tuple[str, str]], List[str], List[Tuple[str, str]]]:
        """Return ``(columns, partition keys, clustering keys)`` for a table.

        Partition and clustering keys are ordered by their ``position``.
        """
        columns = []
        partition_info = []
        cluster_info = []
        results = db.run(
            "SELECT column_name, type, kind, clustering_order, position "
            "FROM system_schema.columns "
            f"WHERE keyspace_name = {Table._cql_literal(keyspace)} "
            f"AND table_name = {Table._cql_literal(table_name)};"
        )
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for raw_row in results:
            row = Table._row_to_dict(raw_row)
            if row is None:
                continue  # Skip rows that cannot be read as a mapping.

            columns.append((row["column_name"], row["type"]))
            if row["kind"] == "partition_key":
                partition_info.append((row["column_name"], row["position"]))
            elif row["kind"] == "clustering":
                cluster_info.append(
                    (row["column_name"], row["clustering_order"], row["position"])
                )

        partition = [
            column_name for column_name, _ in sorted(partition_info, key=lambda x: x[1])
        ]
        cluster = [
            (column_name, clustering_order)
            for column_name, clustering_order, _ in sorted(
                cluster_info, key=lambda x: x[2]
            )
        ]
        return columns, partition, cluster

    @staticmethod
    def _resolve_indexes(
        keyspace: str, table_name: str, db: CassandraDatabase
    ) -> List[Tuple[str, str, str]]:
        """Return ``(index_name, kind, options)`` for each index of a table."""
        indexes = []
        results = db.run(
            "SELECT index_name, kind, options "
            "FROM system_schema.indexes "
            f"WHERE keyspace_name = {Table._cql_literal(keyspace)} "
            f"AND table_name = {Table._cql_literal(table_name)};"
        )
        if not isinstance(results, Sequence):
            raise TypeError("Expected a sequence of dictionaries from 'run' method.")

        for raw_row in results:
            row = Table._row_to_dict(raw_row)
            if row is None:
                continue  # Skip rows that cannot be read as a mapping.

            # ``indexes`` stores options as text, so stringify anything else.
            index_options = row["options"]
            if not isinstance(index_options, str):
                index_options = str(index_options)
            indexes.append((row["index_name"], row["kind"], index_options))
        return indexes