# Source code for langchain_community.graphs.age_graph

from __future__ import annotations

import json
import re
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union

from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore

if TYPE_CHECKING:
    import psycopg2.extras


class AGEQueryException(Exception):
    """Exception raised for errors in AGE queries.

    Args:
        exception: either a plain message string, or a dict carrying a
            ``"message"`` and the error details under ``"details"`` (the
            singular ``"detail"`` key is also accepted, because callers in
            this module have historically used it).
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        # keep ``args`` identical to the default Exception behaviour
        super().__init__(exception)
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # accept both spellings so details are never silently dropped
            self.details = exception.get(
                "details", exception.get("detail", "unknown")
            )
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the human-readable error message."""
        return self.message

    def get_details(self) -> Any:
        """Return the error details (``"unknown"`` when none were given)."""
        return self.details
class AGEGraph(GraphStore):
    """Apache AGE wrapper for graph operations.

    Args:
        graph_name (str): the name of the graph to connect to or create
        conf (Dict[str, Any]): the pgsql connection config passed directly
            to ``psycopg2.connect``
        create (bool): if True and the graph doesn't exist, attempt to create it

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # python type mapping for providing readable types to LLM
    types = {
        "str": "STRING",
        "float": "DOUBLE",
        "int": "INTEGER",
        "list": "LIST",
        "dict": "MAP",
        "bool": "BOOLEAN",
    }

    # precompiled regex for checking chars in graph labels
    label_regex = re.compile("[^0-9a-zA-Z]+")

    # agtype annotations ('<json>::<type>') that _record_to_dict understands
    _AGTYPE_SUFFIXES = ("vertex", "edge", "numeric")

    def __init__(
        self, graph_name: str, conf: Dict[str, Any], create: bool = True
    ) -> None:
        """Create a new AGEGraph instance, creating the graph if requested."""
        self.graph_name = graph_name

        # check that psycopg2 is installed
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Could not import psycopg2 python package. "
                "Please install it with `pip install psycopg2`."
            ) from e

        self.connection = psycopg2.connect(**conf)

        # parameterized so a graph name containing quotes cannot break (or
        # inject into) the catalog lookup
        graph_id_query = "SELECT graphid FROM ag_catalog.ag_graph WHERE name = %s"

        with self._get_cursor() as curs:
            # check if graph with name graph_name exists
            curs.execute(graph_id_query, (graph_name,))
            data = curs.fetchone()

            # if graph doesn't exist and create is True, create it
            if data is None:
                if not create:
                    raise Exception(
                        (
                            'Graph "{}" does not exist in the database '
                            + 'and "create" is set to False'
                        ).format(graph_name)
                    )
                try:
                    curs.execute(
                        "SELECT ag_catalog.create_graph(%s);", (graph_name,)
                    )
                    self.connection.commit()
                except psycopg2.Error as e:
                    # leave the connection usable after the failed statement
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Could not create the graph",
                            "details": str(e),
                        }
                    ) from e

                curs.execute(graph_id_query, (graph_name,))
                data = curs.fetchone()

        # store graph id and refresh the schema
        self.graphid = data.graphid
        self.refresh_schema()

    @staticmethod
    def _import_psycopg2() -> Any:
        """Import and return the psycopg2 module, with an install hint on failure."""
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        return psycopg2

    def _get_cursor(self) -> psycopg2.extras.NamedTupleCursor:
        """Get a named-tuple cursor with the age extension loaded and the
        search path set to include ag_catalog.
        """
        try:
            import psycopg2.extras
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
        cursor.execute("""LOAD 'age';""")
        cursor.execute("""SET search_path = ag_catalog, "$user", public;""")
        return cursor

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get all labels of the graph (for both vertices and edges) by
        running cypher queries over the graph.

        Returns:
            Tuple[List[str], List[str]]: two lists, the first containing
                vertex labels and the second containing edge labels
        """
        e_labels_records = self.query(
            """MATCH ()-[e]-() RETURN collect(distinct label(e)) as labels"""
        )
        e_labels = e_labels_records[0]["labels"] if e_labels_records else []

        n_labels_records = self.query(
            """MATCH (n) RETURN collect(distinct label(n)) as labels"""
        )
        n_labels = n_labels_records[0]["labels"] if n_labels_records else []

        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[Dict[str, str]]:
        """Get a set of distinct relationship types (as a list of dicts) in
        the graph, to be used as context by an llm.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[Dict[str, str]]: relationships as a list of dicts in the format
                "{'start':<from_label>, 'type':<edge_label>, 'end':<to_label>}"

        Raises:
            AGEQueryException: if a query fails (the transaction is rolled back)
        """
        psycopg2 = self._import_psycopg2()

        # age query to get distinct relationship types
        triple_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a)-[e:`{e_label}`]->(b)
            WITH a,e,b LIMIT 3000
            RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
            LIMIT 10
        $$) AS (f agtype, edge agtype, t agtype);
        """

        triple_schema = []

        # iterate desired edge types and add distinct relationship types to result
        with self._get_cursor() as curs:
            for label in e_labels:
                q = triple_query.format(graph_name=self.graph_name, e_label=label)
                try:
                    curs.execute(q)
                    data = curs.fetchall()
                except psycopg2.Error as e:
                    # a failed statement aborts the transaction; reset it
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Error fetching triples",
                            "details": str(e),
                        }
                    ) from e
                for d in data:
                    # use json.loads to convert returned
                    # strings to python primitives
                    triple_schema.append(
                        {
                            "start": json.loads(d.f)[0],
                            "type": json.loads(d.edge),
                            "end": json.loads(d.t)[0],
                        }
                    )

        return triple_schema

    def _get_triples_str(self, e_labels: List[str]) -> List[str]:
        """Get a set of distinct relationship types (as a list of strings) in
        the graph, to be used as context by an llm.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[str]: relationships as a list of strings in the format
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """
        triples = self._get_triples(e_labels)

        return self._format_triples(triples)

    @staticmethod
    def _format_triples(triples: List[Dict[str, str]]) -> List[str]:
        """Convert a list of relationships from dictionaries to formatted
        strings, to be more readable by an llm.

        Args:
            triples (List[Dict[str, str]]): a list of relationships in the form
                {'start':<from_label>, 'type':<edge_label>, 'end':<to_label>}

        Returns:
            List[str]: a list of relationships in the form
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """
        triple_template = "(:`{start}`)-[:`{type}`]->(:`{end}`)"
        triple_schema = [triple_template.format(**triple) for triple in triples]

        return triple_schema

    def _collect_label_properties(
        self,
        query_template: str,
        labels: List[str],
        result_key: str,
        error_message: str,
    ) -> List[Dict[str, Any]]:
        """Sample properties for each label and return distinct (name, type) pairs.

        Args:
            query_template: SQL/cypher template with ``{graph_name}`` and
                ``{label}`` placeholders returning a single ``props`` column
            labels: labels to inspect
            result_key: key under which each label is stored in the result
            error_message: message used if a query fails

        Returns:
            One dict per label: ``{"properties": [...], result_key: label}``.
            Properties are sorted so the schema text is deterministic.
        """
        psycopg2 = self._import_psycopg2()

        results = []
        with self._get_cursor() as curs:
            for label in labels:
                q = query_template.format(graph_name=self.graph_name, label=label)
                try:
                    curs.execute(q)
                    data = curs.fetchall()
                except psycopg2.Error as e:
                    # a failed statement aborts the transaction; reset it
                    self.connection.rollback()
                    raise AGEQueryException(
                        {"message": error_message, "details": str(e)}
                    ) from e

                # build a set of distinct (property, readable type) pairs
                distinct_props = set()
                for d in data:
                    # use json.loads to convert to python
                    # primitive and get readable type
                    for k, v in json.loads(d.props).items():
                        distinct_props.add((k, self.types[type(v).__name__]))

                results.append(
                    {
                        "properties": [
                            {"property": k, "type": v}
                            for k, v in sorted(distinct_props)
                        ],
                        result_key: label,
                    }
                )
        return results

    def _get_node_properties(self, n_labels: List[str]) -> List[Dict[str, Any]]:
        """Fetch a list of available node properties by node label, to be
        used as context for an llm.

        Args:
            n_labels (List[str]): a list of node labels to filter for

        Returns:
            List[Dict[str, Any]]: a list of node labels and their corresponding
                properties in the form
                "{
                    'labels': <node_label>,
                    'properties': [
                        {
                            'property': <property_name>,
                            'type': <property_type>
                        },...
                        ]
                }"
        """
        # cypher query to fetch properties of a given label
        node_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a:`{label}`)
            RETURN properties(a) AS props
            LIMIT 100
        $$) AS (props agtype);
        """
        return self._collect_label_properties(
            node_properties_query,
            n_labels,
            "labels",
            "Error fetching node properties",
        )

    def _get_edge_properties(self, e_labels: List[str]) -> List[Dict[str, Any]]:
        """Fetch a list of available edge properties by edge label, to be
        used as context for an llm.

        Args:
            e_labels (List[str]): a list of edge labels to filter for

        Returns:
            List[Dict[str, Any]]: a list of edge labels and their corresponding
                properties in the form
                "{
                    'type': <edge_label>,
                    'properties': [
                        {
                            'property': <property_name>,
                            'type': <property_type>
                        },...
                        ]
                }"
        """
        # cypher query to fetch properties of a given label
        edge_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH ()-[e:`{label}`]->()
            RETURN properties(e) AS props
            LIMIT 100
        $$) AS (props agtype);
        """
        return self._collect_label_properties(
            edge_properties_query,
            e_labels,
            "type",
            "Error fetching edge properties",
        )

    def refresh_schema(self) -> None:
        """Refresh the graph schema information by updating the available
        labels, relationships, and properties.
        """
        # fetch graph schema information
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)

        node_properties = self._get_node_properties(n_labels)
        edge_properties = self._get_edge_properties(e_labels)

        # update the formatted string representation
        self.schema = f"""
        Node properties are the following:
        {node_properties}
        Relationship properties are the following:
        {edge_properties}
        The relationships are the following:
        {self._format_triples(triple_schema)}
        """

        # update the dictionary representation
        self.structured_schema = {
            "node_props": {el["labels"]: el["properties"] for el in node_properties},
            "rel_props": {el["type"]: el["properties"] for el in edge_properties},
            "relationships": triple_schema,
            "metadata": {},
        }

    @property
    def get_schema(self) -> str:
        """Return the schema of the graph as a string."""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the structured schema of the graph."""
        return self.structured_schema

    @staticmethod
    def _get_col_name(field: str, idx: int) -> str:
        """Convert a cypher return field to a pgsql select field.

        If possible keep the cypher column name, but create a generic name
        if necessary.

        Args:
            field (str): a return field from a cypher query to be formatted
                for pgsql
            idx (int): the position of the field in the return statement

        Returns:
            str: the field to be used in the pgsql select statement
        """
        # remove white space
        field = field.strip()
        # if an alias is provided for the field, use it
        if " as " in field:
            return field.split(" as ")[-1].strip()
        # if the return value is an unnamed primitive, give it a generic name
        elif field.isnumeric() or field in ("true", "false", "null"):
            return f"column_{idx}"
        # otherwise return the value stripping out some common special chars
        else:
            return field.replace("(", "_").replace(")", "")

    @staticmethod
    def _wrap_query(query: str, graph_name: str) -> str:
        """Convert a cypher query to an Apache AGE compatible SQL query by
        wrapping it in ag_catalog.cypher, casting results to agtype and
        building a select statement.

        Args:
            query (str): a valid cypher query
            graph_name (str): the name of the graph to query

        Returns:
            str: an equivalent pgsql query

        Raises:
            ValueError: if the query uses ``RETURN *``
        """
        # pgsql template
        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
            {query}
        $$) AS ({fields});"""

        lowered = query.lower()

        # if there are any returned fields they must be added to the pgsql query.
        # Keywords are matched on word boundaries so identifiers such as
        # `n.return_date` or `n.skipped` are not mistaken for clauses.
        if re.search(r"\breturn\b", lowered):
            # parse return statement to identify returned fields
            return_clause = re.split(r"\breturn\b", lowered)[-1]
            # only a leading DISTINCT modifies the whole projection;
            # e.g. count(DISTINCT x) must stay inside its field
            return_clause = re.sub(r"^\s*distinct\b", "", return_clause)
            return_clause = re.split(
                r"\border\s+by\b|\bskip\b|\blimit\b", return_clause
            )[0]
            fields = return_clause.split(",")

            # raise exception if RETURN * is found as we can't resolve the fields
            if "*" in [x.strip() for x in fields]:
                raise ValueError(
                    "AGE graph does not support 'RETURN *'"
                    + " statements in Cypher queries"
                )

            # get pgsql formatted field names
            fields = [
                AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
            ]

            # build resulting pgsql relation
            fields_str = ", ".join(
                [field.split(".")[-1] + " agtype" for field in fields]
            )

        # if no return statement we still need to return a single field of type agtype
        else:
            fields_str = "a agtype"

        select_str = "*"

        return template.format(
            graph_name=graph_name,
            query=query,
            fields=fields_str,
            projection=select_str,
        )

    @staticmethod
    def _split_agtype(value: Any) -> Tuple[Any, str]:
        """Split an agtype string ``'<json>::<type>'`` into ``(json, type)``.

        Only the trailing annotation is inspected, and only known AGE
        annotations are recognized, so plain string values that themselves
        contain ``'::'`` are returned untouched with an empty type.
        """
        if isinstance(value, str) and "::" in value:
            body, _, dtype = value.rpartition("::")
            if dtype in AGEGraph._AGTYPE_SUFFIXES:
                return body, dtype
        return value, ""

    @staticmethod
    def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
        """Convert a record returned from an AGE query to a dictionary.

        Args:
            record (): a record from an AGE query result

        Returns:
            Dict[str, Any]: a dictionary representation of the record where
                the dictionary key is the field name and the value is the
                value converted to a python type. Vertices become their
                property dicts; edges become
                ``(start_vertex_props, label, end_vertex_props)`` tuples.
        """
        # parse each field once: (json text or raw value, agtype annotation)
        parsed = {
            k: AGEGraph._split_agtype(getattr(record, k)) for k in record._fields
        }

        # prebuild a mapping of vertex_id to vertex properties to be used
        # later to build edges
        vertices = {}
        for v, dtype in parsed.values():
            if dtype == "vertex":
                vertex = json.loads(v)
                vertices[vertex["id"]] = vertex.get("properties")

        # iterate returned fields and parse appropriately
        d = {}
        for k, (v, dtype) in parsed.items():
            if dtype == "vertex":
                d[k] = json.loads(v).get("properties")
            # convert edge from id-label->id by replacing id with node information
            # we only do this if the vertex was also returned in the query
            # this is an attempt to be consistent with neo4j implementation
            elif dtype == "edge":
                edge = json.loads(v)
                d[k] = (
                    vertices.get(edge["start_id"], {}),
                    edge["label"],
                    vertices.get(edge["end_id"], {}),
                )
            else:
                d[k] = json.loads(v) if isinstance(v, str) else v

        return d

    def query(
        self, query: str, params: Union[Dict[str, Any], None] = None
    ) -> List[Dict[str, Any]]:
        """Query the graph by taking a cypher query, converting it to an
        AGE compatible query, executing it and converting the result.

        Args:
            query (str): a cypher query to be executed
            params (dict): parameters for the query (not used in this
                implementation)

        Returns:
            List[Dict[str, Any]]: a list of dictionaries containing the result set

        Raises:
            AGEQueryException: if execution fails (the transaction is rolled back)
        """
        psycopg2 = self._import_psycopg2()

        # convert cypher query to pgsql/age query
        wrapped_query = self._wrap_query(query, self.graph_name)

        # execute the query, rolling back on an error
        with self._get_cursor() as curs:
            try:
                curs.execute(wrapped_query)
                # read the results before committing the transaction
                data = curs.fetchall()
                self.connection.commit()
            except psycopg2.Error as e:
                self.connection.rollback()
                raise AGEQueryException(
                    {
                        "message": "Error executing graph query: {}".format(query),
                        "details": str(e),
                    }
                ) from e

        # convert to dictionaries
        return [self._record_to_dict(d) for d in data]

    @staticmethod
    def _format_properties(
        properties: Dict[str, Any], id: Union[str, None] = None
    ) -> str:
        """Convert a dictionary of properties to a string representation that
        can be used in a cypher query insert/merge statement.

        Args:
            properties (Dict[str, Any]): a dictionary containing node/edge properties
            id (Union[str, None]): the id of the node or None if none exists;
                only added when ``properties`` has no ``"id"`` key

        Returns:
            str: the properties dictionary as a properly formatted string
        """
        props = []
        # wrap property key in backticks to escape
        for k, v in properties.items():
            prop = f"`{k}`: {json.dumps(v)}"
            props.append(prop)
        if id is not None and "id" not in properties:
            props.append(
                f"id: {json.dumps(id)}" if isinstance(id, str) else f"id: {id}"
            )
        return "{" + ", ".join(props) + "}"

    @staticmethod
    def clean_graph_labels(label: str) -> str:
        """Remove any disallowed characters from a label and replace with '_'.

        Args:
            label (str): the original label

        Returns:
            str: the cleaned version of the label
        """
        return AGEGraph.label_regex.sub("_", label)

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Insert a list of graph documents into the graph.

        Args:
            graph_documents (List[GraphDocument]): the list of documents to be
                inserted
            include_source (bool): if True add nodes for the sources with
                MENTIONS edges to the entities they mention; a source without
                a metadata ``"id"`` gets one set to the md5 of its content

        Returns:
            None
        """
        # query for inserting nodes
        node_insert_query = (
            """
            MERGE (n:`{label}` {properties})
            """
            if not include_source
            else """
            MERGE (n:`{label}` {properties})
            MERGE (d:Document {d_properties})
            MERGE (d)-[:MENTIONS]->(n)
        """
        )

        # query for inserting edges
        edge_insert_query = """
            MERGE (from:`{f_label}` {f_properties})
            MERGE (to:`{t_label}` {t_properties})
            MERGE (from)-[:`{r_label}` {r_properties}]->(to)
        """

        # iterate docs and insert them
        for doc in graph_documents:
            # if we are adding sources, create an id for the source
            if include_source and not doc.source.metadata.get("id"):
                doc.source.metadata["id"] = md5(
                    doc.source.page_content.encode("utf-8")
                ).hexdigest()

            # insert entity nodes; property dicts are copied so the caller's
            # nodes are not mutated, and the label is always cleaned
            for node in doc.nodes:
                node_args = {
                    "label": AGEGraph.clean_graph_labels(node.type),
                    "properties": self._format_properties(
                        {**node.properties, "id": node.id}
                    ),
                }
                if include_source:
                    node_args["d_properties"] = self._format_properties(
                        doc.source.metadata
                    )
                self.query(node_insert_query.format(**node_args))

            # insert relationships
            for edge in doc.relationships:
                inputs = {
                    "f_label": AGEGraph.clean_graph_labels(edge.source.type),
                    "f_properties": self._format_properties(
                        {**edge.source.properties, "id": edge.source.id}
                    ),
                    "t_label": AGEGraph.clean_graph_labels(edge.target.type),
                    "t_properties": self._format_properties(
                        {**edge.target.properties, "id": edge.target.id}
                    ),
                    "r_label": AGEGraph.clean_graph_labels(edge.type).upper(),
                    "r_properties": self._format_properties(edge.properties),
                }

                self.query(edge_insert_query.format(**inputs))