Source code for langchain_community.graphs.nebula_graph

import logging
from string import Template
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

rel_query = Template(
    """
MATCH ()-[e:`$edge_type`]->()
  WITH e limit 1
MATCH (m)-[:`$edge_type`]->(n) WHERE id(m) == src(e) AND id(n) == dst(e)
RETURN "(:" + tags(m)[0] + ")-[:$edge_type]->(:" + tags(n)[0] + ")" AS rels
"""
)

RETRY_TIMES = 3


[docs]class NebulaGraph: """NebulaGraph用于图操作的包装器。 NebulaGraph继承了Neo4jGraph的方法,以使用户空间更加便捷。 *安全提示*:确保数据库连接使用的凭据范围狭窄,仅包括必要的权限。 如果未能这样做,可能会导致数据损坏或丢失,因为调用代码可能会尝试执行会导致删除、变异数据或读取敏感数据(如果数据库中存在此类数据)的命令。 防范这种负面结果的最佳方法是(视情况)限制授予此工具使用的凭据的权限。 有关更多信息,请参见https://python.langchain.com/docs/security。"""
[docs] def __init__( self, space: str, username: str = "root", password: str = "nebula", address: str = "127.0.0.1", port: int = 9669, session_pool_size: int = 30, ) -> None: """创建一个新的NebulaGraph包装实例。""" try: import nebula3 # noqa: F401 import pandas # noqa: F401 except ImportError: raise ImportError( "Please install NebulaGraph Python client and pandas first: " "`pip install nebula3-python pandas`" ) self.username = username self.password = password self.address = address self.port = port self.space = space self.session_pool_size = session_pool_size self.session_pool = self._get_session_pool() self.schema = "" # Set schema try: self.refresh_schema() except Exception as e: raise ValueError(f"Could not refresh schema. Error: {e}")
def _get_session_pool(self) -> Any: assert all( [self.username, self.password, self.address, self.port, self.space] ), ( "Please provide all of the following parameters: " "username, password, address, port, space" ) from nebula3.Config import SessionPoolConfig from nebula3.Exception import AuthFailedException, InValidHostname from nebula3.gclient.net.SessionPool import SessionPool config = SessionPoolConfig() config.max_size = self.session_pool_size try: session_pool = SessionPool( self.username, self.password, self.space, [(self.address, self.port)], ) except InValidHostname: raise ValueError( "Could not connect to NebulaGraph database. " "Please ensure that the address and port are correct" ) try: session_pool.init(config) except AuthFailedException: raise ValueError( "Could not connect to NebulaGraph database. " "Please ensure that the username and password are correct" ) except RuntimeError as e: raise ValueError(f"Error initializing session pool. Error: {e}") return session_pool def __del__(self) -> None: try: self.session_pool.close() except Exception as e: logger.warning(f"Could not close session pool. Error: {e}") @property def get_schema(self) -> str: """返回NebulaGraph数据库的模式""" return self.schema
[docs] def execute(self, query: str, params: Optional[dict] = None, retry: int = 0) -> Any: """查询NebulaGraph数据库。""" from nebula3.Exception import IOErrorException, NoValidSessionException from nebula3.fbthrift.transport.TTransport import TTransportException params = params or {} try: result = self.session_pool.execute_parameter(query, params) if not result.is_succeeded(): logger.warning( f"Error executing query to NebulaGraph. " f"Error: {result.error_msg()}\n" f"Query: {query} \n" ) return result except NoValidSessionException: logger.warning( f"No valid session found in session pool. " f"Please consider increasing the session pool size. " f"Current size: {self.session_pool_size}" ) raise ValueError( f"No valid session found in session pool. " f"Please consider increasing the session pool size. " f"Current size: {self.session_pool_size}" ) except RuntimeError as e: if retry < RETRY_TIMES: retry += 1 logger.warning( f"Error executing query to NebulaGraph. " f"Retrying ({retry}/{RETRY_TIMES})...\n" f"query: {query} \n" f"Error: {e}" ) return self.execute(query, params, retry) else: raise ValueError(f"Error executing query to NebulaGraph. Error: {e}") except (TTransportException, IOErrorException): # connection issue, try to recreate session pool if retry < RETRY_TIMES: retry += 1 logger.warning( f"Connection issue with NebulaGraph. " f"Retrying ({retry}/{RETRY_TIMES})...\n to recreate session pool" ) self.session_pool = self._get_session_pool() return self.execute(query, params, retry)
[docs] def refresh_schema(self) -> None: """ 刷新NebulaGraph的模式信息。 """ tags_schema, edge_types_schema, relationships = [], [], [] for tag in self.execute("SHOW TAGS").column_values("Name"): tag_name = tag.cast() tag_schema = {"tag": tag_name, "properties": []} r = self.execute(f"DESCRIBE TAG `{tag_name}`") props, types = r.column_values("Field"), r.column_values("Type") for i in range(r.row_size()): tag_schema["properties"].append((props[i].cast(), types[i].cast())) tags_schema.append(tag_schema) for edge_type in self.execute("SHOW EDGES").column_values("Name"): edge_type_name = edge_type.cast() edge_schema = {"edge": edge_type_name, "properties": []} r = self.execute(f"DESCRIBE EDGE `{edge_type_name}`") props, types = r.column_values("Field"), r.column_values("Type") for i in range(r.row_size()): edge_schema["properties"].append((props[i].cast(), types[i].cast())) edge_types_schema.append(edge_schema) # build relationships types r = self.execute( rel_query.substitute(edge_type=edge_type_name) ).column_values("rels") if len(r) > 0: relationships.append(r[0].cast()) self.schema = ( f"Node properties: {tags_schema}\n" f"Edge properties: {edge_types_schema}\n" f"Relationships: {relationships}\n" )
[docs] def query(self, query: str, retry: int = 0) -> Dict[str, Any]: result = self.execute(query, retry=retry) columns = result.keys() d: Dict[str, list] = {} for col_num in range(result.col_size()): col_name = columns[col_num] col_list = result.column_values(col_name) d[col_name] = [x.cast() for x in col_list] return d