Source code for langchain_community.utilities.cassandra_database

"""Apache Cassandra数据库封装。"""
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain_core.pydantic_v1 import BaseModel, Field, root_validator

if TYPE_CHECKING:
    from cassandra.cluster import ResultSet, Session

IGNORED_KEYSPACES = [
    "system",
    "system_auth",
    "system_distributed",
    "system_schema",
    "system_traces",
    "system_views",
    "datastax_sla",
    "data_endpoint_auth",
]


[docs]class CassandraDatabase: """Apache Cassandra®数据库包装器。"""
[docs] def __init__( self, session: Optional[Session] = None, exclude_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, cassio_init_kwargs: Optional[Dict[str, Any]] = None, ): _session = self._resolve_session(session, cassio_init_kwargs) if not _session: raise ValueError("Session not provided and cannot be resolved") self._session = _session self._exclude_keyspaces = IGNORED_KEYSPACES self._exclude_tables = exclude_tables or [] self._include_tables = include_tables or []
[docs] def run( self, query: str, fetch: str = "all", **kwargs: Any, ) -> Union[list, Dict[str, Any], ResultSet]: """执行一个CQL查询并返回结果。""" if fetch == "all": return self.fetch_all(query, **kwargs) elif fetch == "one": return self.fetch_one(query, **kwargs) elif fetch == "cursor": return self._fetch(query, **kwargs) else: raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")
def _fetch(self, query: str, **kwargs: Any) -> ResultSet: clean_query = self._validate_cql(query, "SELECT") return self._session.execute(clean_query, **kwargs)
[docs] def fetch_all(self, query: str, **kwargs: Any) -> list: return list(self._fetch(query, **kwargs))
[docs] def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]: result = self._fetch(query, **kwargs) return result.one()._asdict() if result else {}
[docs] def get_keyspace_tables(self, keyspace: str) -> List[Table]: """获取指定keyspace的Table对象。""" schema = self._resolve_schema([keyspace]) if keyspace in schema: return schema[keyspace] else: return []
# This is a more basic string building function that doesn't use a query builder # or prepared statements # TODO: Refactor to use prepared statements
[docs] def get_table_data( self, keyspace: str, table: str, predicate: str, limit: int ) -> str: """从指定keyspace中的指定表中获取数据。""" query = f"SELECT * FROM {keyspace}.{table}" if predicate: query += f" WHERE {predicate}" if limit: query += f" LIMIT {limit}" query += ";" result = self.fetch_all(query) data = "\n".join(str(row) for row in result) return data
[docs] def get_context(self) -> Dict[str, Any]: """返回可能在代理提示中需要的数据库上下文。""" keyspaces = self._fetch_keyspaces() return {"keyspaces": ", ".join(keyspaces)}
[docs] def format_keyspace_to_markdown( self, keyspace: str, tables: Optional[List[Table]] = None ) -> str: """生成特定keyspace的模式的markdown表示形式,通过迭代该keyspace中的所有表,并调用它们的as_markdown方法。 参数: keyspace:要为其生成markdown文档的keyspace的名称。 tables:keyspace中的表列表;如果未提供,将被解析。 返回: 包含指定keyspace模式的markdown表示形式的字符串。 """ if not tables: tables = self.get_keyspace_tables(keyspace) if tables: output = f"## Keyspace: {keyspace}\n\n" if tables: for table in tables: output += table.as_markdown(include_keyspace=False, header_level=3) output += "\n\n" else: output += "No tables present in keyspace\n\n" return output else: return ""
[docs] def format_schema_to_markdown(self) -> str: """生成CassandraDatabase实例中所有keyspaces和tables的模式的markdown表示。此方法利用format_keyspace_to_markdown方法为每个keyspace创建markdown部分,将它们组装成一个全面的模式文档。 遍历数据库中的每个keyspace,利用format_keyspace_to_markdown为每个keyspace的模式生成markdown,包括其表的详细信息。这些部分被连接起来形成一个单一的markdown文档,代表整个数据库或在此实例中已解析的keyspaces子集的模式。 返回: 一个markdown字符串,记录此CassandraDatabase实例中所有已解析keyspaces及其tables的模式。这包括keyspace名称、table名称、注释、列、分区键、聚簇键和每个table的索引。 """ schema = self._resolve_schema() output = "# Cassandra Database Schema\n\n" for keyspace, tables in schema.items(): output += f"{self.format_keyspace_to_markdown(keyspace, tables)}\n\n" return output
def _validate_cql(self, cql: str, type: str = "SELECT") -> str: """验证CQL查询字符串的基本格式和安全性检查。 确保`cql`以指定类型(例如SELECT)开头,并且不包含可能指示CQL注入漏洞的内容。 参数: cql:要验证的CQL查询字符串。 type:查询的预期起始关键字,用于验证查询是否以正确的操作类型(例如"SELECT","UPDATE")开头。默认为"SELECT"。 返回: 修剪和经过验证的CQL查询字符串,不带分号结尾。 引发: ValueError:如果`type`的值不受支持 DatabaseError:如果认为`cql`不安全 """ SUPPORTED_TYPES = ["SELECT"] if type and type.upper() not in SUPPORTED_TYPES: raise ValueError( f"""Unsupported CQL type: {type}. Supported types: {SUPPORTED_TYPES}""" ) # Basic sanity checks cql_trimmed = cql.strip() if not cql_trimmed.upper().startswith(type.upper()): raise DatabaseError(f"CQL must start with {type.upper()}.") # Allow a trailing semicolon, but remove (it is optional with the Python driver) cql_trimmed = cql_trimmed.rstrip(";") # Consider content within matching quotes to be "safe" # Remove single-quoted strings cql_sanitized = re.sub(r"'.*?'", "", cql_trimmed) # Remove double-quoted strings cql_sanitized = re.sub(r'".*?"', "", cql_sanitized) # Find unsafe content in the remaining CQL if ";" in cql_sanitized: raise DatabaseError( """Potentially unsafe CQL, as it contains a ; at a place other than the end or within quotation marks.""" ) # The trimmed query, before modifications return cql_trimmed def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]: """从Cassandra数据库中获取一个键空间名称列表。该列表可以通过提供的键空间名称列表或排除预定义键空间来进行过滤。 参数: keyspaces: 要特别包含的键空间名称列表。 如果提供且不为空,则该方法仅返回此列表中存在的键空间。 如果未提供或为空,则该方法返回除了在_exclude_keyspaces属性中指定的键空间之外的所有键空间。 返回: 根据过滤条件的键空间名称列表。 """ all_keyspaces = self.fetch_all( "SELECT keyspace_name FROM system_schema.keyspaces" ) # Filtering keyspaces based on 'keyspace_list' and '_exclude_keyspaces' filtered_keyspaces = [] for ks in all_keyspaces: if not isinstance(ks, Dict): continue # Skip if the row is not a dictionary. keyspace_name = ks["keyspace_name"] if keyspaces and keyspace_name in keyspaces: filtered_keyspaces.append(keyspace_name) elif not keyspaces and keyspace_name not in self._exclude_keyspaces: filtered_keyspaces.append(keyspace_name) return filtered_keyspaces def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str: # Construct IN clause for CQL query keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspaces]) return f"""{query} WHERE keyspace_name IN ({keyspace_in_clause})""" def _fetch_tables_data(self, keyspaces: List[str]) -> list: """获取由一组keyspaces筛选的表模式数据。 该方法允许有效地获取多个keyspaces的模式信息,从而使应用程序能够以编程方式分析或记录数据库模式。 参数: keyspaces:要获取表模式数据的keyspace名称列表。 返回: 表详细信息的字典(keyspace名称、表名称和注释)。 """ tables_query = self._format_keyspace_query( "SELECT keyspace_name, table_name, comment FROM system_schema.tables", keyspaces, ) return self.fetch_all(tables_query) def _fetch_columns_data(self, keyspaces: List[str]) -> list: """获取按keyspaces列表过滤的列模式数据。 该方法允许有效地在单个操作中获取多个keyspaces的模式信息,使应用程序能够以编程方式分析或记录数据库模式。 参数: keyspaces:要从中获取表模式数据的keyspace名称列表。 返回: 列详细信息的字典(keyspace名称,表名称,列名称,类型,种类和位置)。 """ tables_query = self._format_keyspace_query( """ SELECT keyspace_name, table_name, column_name, type, kind, clustering_order, position FROM system_schema.columns """, keyspaces, ) return self.fetch_all(tables_query) def _fetch_indexes_data(self, keyspaces: List[str]) -> list: """获取按keyspaces列表过滤的索引模式数据。 该方法允许在单个操作中高效地获取多个keyspaces的模式信息,使应用程序能够以编程方式分析或记录数据库模式。 参数: keyspaces:要获取表模式数据的keyspace名称列表。 返回: 索引详细信息的字典(keyspace名称、表名称、索引名称、类型和选项)。 """ tables_query = self._format_keyspace_query( """ SELECT keyspace_name, table_name, index_name, kind, options FROM system_schema.indexes """, keyspaces, ) return self.fetch_all(tables_query) def _resolve_schema( self, keyspaces: Optional[List[str]] = None ) -> Dict[str, List[Table]]: """高效地获取和整理Cassandra表模式信息,如注释、列和索引,将其组织成一个字典,将keyspace名称映射到Table对象的列表。 参数: keyspaces: 一个可选的keyspace名称列表,用于获取表模式数据。 返回: 一个字典,以keyspace名称为键,以Table对象的列表为值,其中每个Table对象都填充有适用于其keyspace和表名称的模式详细信息。 """ if not keyspaces: keyspaces = self._fetch_keyspaces() tables_data = self._fetch_tables_data(keyspaces) columns_data = self._fetch_columns_data(keyspaces) indexes_data = self._fetch_indexes_data(keyspaces) keyspace_dict: dict = {} for table_data in tables_data: keyspace = table_data.keyspace_name table_name = table_data.table_name comment = table_data.comment if self._include_tables and table_name not in self._include_tables: continue if self._exclude_tables and table_name in self._exclude_tables: continue # Filter columns and indexes for this table table_columns = [ (c.column_name, c.type) for c in columns_data if c.keyspace_name == keyspace and c.table_name == table_name ] partition_keys = [ c.column_name for c in columns_data if c.kind == "partition_key" and c.keyspace_name == keyspace and c.table_name == table_name ] clustering_keys = [ (c.column_name, c.clustering_order) for c in columns_data if c.kind == "clustering" and c.keyspace_name == keyspace and c.table_name == table_name ] table_indexes = [ (c.index_name, c.kind, c.options) for c in indexes_data if c.keyspace_name == keyspace and c.table_name == table_name ] table_obj = Table( keyspace=keyspace, table_name=table_name, comment=comment, columns=table_columns, partition=partition_keys, clustering=clustering_keys, indexes=table_indexes, ) if keyspace not in keyspace_dict: keyspace_dict[keyspace] = [] keyspace_dict[keyspace].append(table_obj) return keyspace_dict @staticmethod def _resolve_session( session: Optional[Session] = None, cassio_init_kwargs: Optional[Dict[str, Any]] = None, ) -> Optional[Session]: """尝试解析并返回一个用于数据库操作的Session对象。 此函数遵循特定的优先顺序来确定要使用的适当会话: 1. 如果提供了`session`参数,则使用该参数, 2. 使用现有的`cassio`会话, 3. 使用从`cassio_init_kwargs`派生的新的`cassio`会话, 4. 返回`None` 参数: session: 可选的直接使用的会话。 cassio_init_kwargs: 传递给`cassio`的可选关键字参数字典。 返回: 如果成功,返回解析后的会话对象,如果无法解析会话,则返回`None`。 Raises: ValueError: 如果提供了`cassio_init_kwargs`但不是关键字参数字典。 """ # Prefer given session if session: return session # If a session is not provided, create one using cassio if available # dynamically import cassio to avoid circular imports try: import cassio.config except ImportError: raise ValueError( "cassio package not found, please install with" " `pip install cassio`" ) # Use pre-existing session on cassio s = cassio.config.resolve_session() if s: return s # Try to init and return cassio session if cassio_init_kwargs: if isinstance(cassio_init_kwargs, dict): cassio.init(**cassio_init_kwargs) s = cassio.config.check_resolve_session() return s else: raise ValueError("cassio_init_kwargs must be a keyword dictionary") # return None if we're not able to resolve return None
[docs]class DatabaseError(Exception): """数据库模式中出现错误时引发的异常。 属性: message -- 错误的解释""" def __init__(self, message: str): self.message = message super().__init__(self.message)
[docs]class Table(BaseModel): keyspace: str """表存在的键空间。""" table_name: str """表的名称。""" comment: Optional[str] = None """与表相关的注释。""" columns: List[Tuple[str, str]] = Field(default_factory=list) partition: List[str] = Field(default_factory=list) clustering: List[Tuple[str, str]] = Field(default_factory=list) indexes: List[Tuple[str, str, str]] = Field(default_factory=list) class Config: frozen = True @root_validator() def check_required_fields(cls, class_values: dict) -> dict: if not class_values["columns"]: raise ValueError("non-empty column list for must be provided") if not class_values["partition"]: raise ValueError("non-empty partition list must be provided") return class_values
[docs] @classmethod def from_database( cls, keyspace: str, table_name: str, db: CassandraDatabase ) -> Table: columns, partition, clustering = cls._resolve_columns(keyspace, table_name, db) return cls( keyspace=keyspace, table_name=table_name, comment=cls._resolve_comment(keyspace, table_name, db), columns=columns, partition=partition, clustering=clustering, indexes=cls._resolve_indexes(keyspace, table_name, db), )
[docs] def as_markdown( self, include_keyspace: bool = True, header_level: Optional[int] = None ) -> str: """生成Cassandra表模式的Markdown表示,允许为表名部分设置可定制的标题级别。 参数: include_keyspace:如果为True,则在输出中包含keyspace。 默认为True。 header_level:指定表名的Markdown标题级别。 如果为None,则表名将不包含标题。 默认为None(无标题级别)。 返回: 以Markdown格式返回一个字符串,详细描述表名(可选标题级别)、keyspace(可选)、注释、列、分区键、聚簇键(可选排序方式)和索引。 """ output = "" if header_level is not None: output += f"{'#' * header_level} " output += f"Table Name: {self.table_name}\n" if include_keyspace: output += f"- Keyspace: {self.keyspace}\n" if self.comment: output += f"- Comment: {self.comment}\n" output += "- Columns\n" for column, type in self.columns: output += f" - {column} ({type})\n" output += f"- Partition Keys: ({', '.join(self.partition)})\n" output += "- Clustering Keys: " if self.clustering: cluster_list = [] for column, clustering_order in self.clustering: if clustering_order.lower() == "none": cluster_list.append(column) else: cluster_list.append(f"{column} {clustering_order}") output += f"({', '.join(cluster_list)})\n" if self.indexes: output += "- Indexes\n" for name, kind, options in self.indexes: output += f" - {name} : kind={kind}, options={options}\n" return output
@staticmethod def _resolve_comment( keyspace: str, table_name: str, db: CassandraDatabase ) -> Optional[str]: result = db.run( f"""SELECT comment FROM system_schema.tables WHERE keyspace_name = '{keyspace}' AND table_name = '{table_name}';""", fetch="one", ) if isinstance(result, dict): comment = result.get("comment") if comment: return comment else: return None # Default comment if none is found else: raise ValueError( f"""Unexpected result type from db.run: {type(result).__name__}""" ) @staticmethod def _resolve_columns( keyspace: str, table_name: str, db: CassandraDatabase ) -> Tuple[List[Tuple[str, str]], List[str], List[Tuple[str, str]]]: columns = [] partition_info = [] cluster_info = [] results = db.run( f"""SELECT column_name, type, kind, clustering_order, position FROM system_schema.columns WHERE keyspace_name = '{keyspace}' AND table_name = '{table_name}';""" ) # Type check to ensure 'results' is a sequence of dictionaries. if not isinstance(results, Sequence): raise TypeError("Expected a sequence of dictionaries from 'run' method.") for row in results: if not isinstance(row, Dict): continue # Skip if the row is not a dictionary. columns.append((row["column_name"], row["type"])) if row["kind"] == "partition_key": partition_info.append((row["column_name"], row["position"])) elif row["kind"] == "clustering": cluster_info.append( (row["column_name"], row["clustering_order"], row["position"]) ) partition = [ column_name for column_name, _ in sorted(partition_info, key=lambda x: x[1]) ] cluster = [ (column_name, clustering_order) for column_name, clustering_order, _ in sorted( cluster_info, key=lambda x: x[2] ) ] return columns, partition, cluster @staticmethod def _resolve_indexes( keyspace: str, table_name: str, db: CassandraDatabase ) -> List[Tuple[str, str, str]]: indexes = [] results = db.run( f"""SELECT index_name, kind, options FROM system_schema.indexes WHERE keyspace_name = '{keyspace}' AND table_name = '{table_name}';""" ) # Type check to ensure 'results' is a sequence of dictionaries if not isinstance(results, Sequence): raise TypeError("Expected a sequence of dictionaries from 'run' method.") for row in results: if not isinstance(row, Dict): continue # Skip if the row is not a dictionary. # Convert 'options' to string if it's not already, # assuming it's JSON-like and needs conversion index_options = row["options"] if not isinstance(index_options, str): # Assuming index_options needs to be serialized or simply converted index_options = str(index_options) indexes.append((row["index_name"], row["kind"], index_options)) return indexes