from __future__ import annotations
import json
import re
from hashlib import md5
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Tuple, Union
from langchain_community.graphs.graph_document import GraphDocument
from langchain_community.graphs.graph_store import GraphStore
if TYPE_CHECKING:
import psycopg2.extras
class AGEQueryException(Exception):
    """Exception raised when an Apache AGE query fails.

    Args:
        exception: either a plain message string, or a dict with a
            ``"message"`` key and a ``"details"`` (or ``"detail"``) key
            describing the underlying database error.
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        # keep the original argument in ``args`` so str()/repr() show it
        super().__init__(exception)
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # Accept both "details" and the singular "detail" key so the
            # database error text is not silently dropped.
            self.details = exception.get(
                "details", exception.get("detail", "unknown")
            )
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the human-readable error message."""
        return self.message

    def get_details(self) -> Any:
        """Return the underlying error details, or ``"unknown"``."""
        return self.details
class AGEGraph(GraphStore):
    """Apache AGE wrapper for graph operations.

    Args:
        graph_name (str): the name of the graph to connect to or create
        conf (Dict[str, Any]): the pgsql connection config passed directly
            to psycopg2.connect
        create (bool): if True and the graph doesn't exist, attempt to create it

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the
        calling code may attempt commands that would result in deletion or
        mutation of data (if appropriately prompted), or reading sensitive
        data (if such data is present in the database).
        The best way to guard against such negative outcomes is to (as
        appropriate) limit the permissions granted to the credentials used
        with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # python type mapping for providing readable types to LLM
    types = {
        "str": "STRING",
        "float": "DOUBLE",
        "int": "INTEGER",
        "list": "LIST",
        "dict": "MAP",
        "bool": "BOOLEAN",
    }

    # precompiled regex for checking chars in graph labels
    label_regex = re.compile("[^0-9a-zA-Z]+")

    # Keyword patterns used to locate the RETURN clause of a lower-cased
    # Cypher query. Word boundaries keep identifiers such as ``n.returned``
    # or ``n.skipped`` from being mistaken for keywords.
    _return_regex = re.compile(r"\breturn\b")
    _distinct_regex = re.compile(r"\bdistinct\b")
    _return_end_regex = re.compile(r"\border\s+by\b|\bskip\b|\blimit\b")

    # unnamed numeric or quoted-string literals in a RETURN clause
    _literal_regex = re.compile(r"""^(-?\d+(\.\d+)?([eE][+-]?\d+)?|'[^']*'|"[^"]*")$""")

    # type suffixes AGE appends to serialized agtype values, e.g. '{...}::vertex'
    _agtype_suffixes = frozenset({"vertex", "edge", "path", "numeric"})

    def __init__(
        self, graph_name: str, conf: Dict[str, Any], create: bool = True
    ) -> None:
        """Create a new AGEGraph instance.

        Connects with ``psycopg2.connect(**conf)``, looks up (and, if
        ``create`` is True, creates) the graph, then loads its schema.

        Raises:
            ImportError: if psycopg2 is not installed.
            AGEQueryException: if the graph could not be created.
            Exception: if the graph does not exist and ``create`` is False.
        """
        self.graph_name = graph_name

        # check that psycopg2 is installed
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError(
                "Could not import psycopg2 python package. "
                "Please install it with `pip install psycopg2`."
            ) from e

        self.connection = psycopg2.connect(**conf)
        try:
            # store graph id and refresh the schema
            self.graphid = self._get_or_create_graph_id(graph_name, create)
            self.refresh_schema()
        except Exception:
            # don't leak the freshly opened connection if setup fails
            self.connection.close()
            raise

    def _get_or_create_graph_id(self, graph_name: str, create: bool) -> Any:
        """Return the ``graphid`` of ``graph_name``, creating the graph if allowed.

        Args:
            graph_name: name of the AGE graph
            create: whether to create the graph when it does not exist

        Returns:
            the ``graphid`` column from ``ag_catalog.ag_graph``
        """
        psycopg2 = self._import_psycopg2()
        # parameterized so the graph name cannot inject SQL
        graph_id_query = "SELECT graphid FROM ag_catalog.ag_graph WHERE name = %s"
        with self._get_cursor() as curs:
            curs.execute(graph_id_query, (graph_name,))
            data = curs.fetchone()

            # if graph doesn't exist and create is True, create it
            if data is None:
                if not create:
                    raise Exception(
                        (
                            'Graph "{}" does not exist in the database '
                            + 'and "create" is set to False'
                        ).format(graph_name)
                    )
                try:
                    curs.execute("SELECT ag_catalog.create_graph(%s);", (graph_name,))
                    self.connection.commit()
                except psycopg2.Error as e:
                    self.connection.rollback()
                    raise AGEQueryException(
                        {
                            "message": "Could not create the graph",
                            "details": str(e),
                        }
                    ) from e

                curs.execute(graph_id_query, (graph_name,))
                data = curs.fetchone()

        return data.graphid

    @staticmethod
    def _import_psycopg2() -> Any:
        """Import and return ``psycopg2`` with ``psycopg2.extras`` loaded.

        Raises:
            ImportError: if psycopg2 is not installed.
        """
        try:
            import psycopg2
            import psycopg2.extras  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "Unable to import psycopg2, please install with "
                "`pip install -U psycopg2`."
            ) from e
        return psycopg2

    def _get_cursor(self) -> psycopg2.extras.NamedTupleCursor:
        """Return a named-tuple cursor with the AGE extension loaded and the
        search path set to ``ag_catalog, "$user", public``."""
        psycopg2 = self._import_psycopg2()
        cursor = self.connection.cursor(cursor_factory=psycopg2.extras.NamedTupleCursor)
        cursor.execute("""LOAD 'age';""")
        cursor.execute("""SET search_path = ag_catalog, "$user", public;""")
        return cursor

    def _execute_and_fetch(self, curs: Any, sql: str, error_message: str) -> List[Any]:
        """Execute ``sql`` on ``curs`` and return all rows.

        On a database error the transaction is rolled back (so the connection
        stays usable) and an AGEQueryException with ``error_message`` is raised.
        """
        psycopg2 = self._import_psycopg2()
        try:
            curs.execute(sql)
            return curs.fetchall()
        except psycopg2.Error as e:
            self.connection.rollback()
            raise AGEQueryException(
                {
                    "message": error_message,
                    "details": str(e),
                }
            ) from e

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get all labels of the graph (both edges and vertices).

        Returns:
            Tuple[List[str], List[str]]: the vertex labels and the edge labels
        """
        e_labels_records = self.query(
            """MATCH ()-[e]-() RETURN collect(distinct label(e)) as labels"""
        )
        e_labels = e_labels_records[0]["labels"] if e_labels_records else []

        n_labels_records = self.query(
            """MATCH (n) RETURN collect(distinct label(n)) as labels"""
        )
        n_labels = n_labels_records[0]["labels"] if n_labels_records else []

        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[Dict[str, str]]:
        """Get the distinct relationship types in the graph, for LLM context.

        Args:
            e_labels (List[str]): edge labels to inspect

        Returns:
            List[Dict[str, str]]: relationships as dicts of the form
                {'start': <from_label>, 'type': <edge_label>, 'end': <to_label>}
        """
        # age query to get distinct relationship types
        triple_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a)-[e:`{e_label}`]->(b)
            WITH a,e,b LIMIT 3000
            RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
            LIMIT 10
        $$) AS (f agtype, edge agtype, t agtype);
        """

        triple_schema = []

        # iterate desired edge types and add distinct relationship types to result
        with self._get_cursor() as curs:
            for label in e_labels:
                q = triple_query.format(graph_name=self.graph_name, e_label=label)
                for d in self._execute_and_fetch(curs, q, "Error fetching triples"):
                    # json.loads converts the returned agtype strings
                    # into python primitives
                    triple_schema.append(
                        {
                            "start": json.loads(d.f)[0],
                            "type": json.loads(d.edge),
                            "end": json.loads(d.t)[0],
                        }
                    )

        return triple_schema

    def _get_triples_str(self, e_labels: List[str]) -> List[str]:
        """Get the distinct relationship types in the graph as strings.

        Args:
            e_labels (List[str]): edge labels to inspect

        Returns:
            List[str]: relationships formatted as
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """
        triples = self._get_triples(e_labels)
        return self._format_triples(triples)

    @staticmethod
    def _format_triples(triples: List[Dict[str, str]]) -> List[str]:
        """Convert relationship dicts into strings that are easier for an LLM to read.

        Args:
            triples (List[Dict[str, str]]): relationships of the form
                {'start': <from_label>, 'type': <edge_label>, 'end': <to_label>}

        Returns:
            List[str]: relationships formatted as
                "(:`<from_label>`)-[:`<edge_label>`]->(:`<to_label>`)"
        """
        triple_template = "(:`{start}`)-[:`{type}`]->(:`{end}`)"
        return [triple_template.format(**triple) for triple in triples]

    @classmethod
    def _readable_type(cls, value: Any) -> str:
        """Map a JSON-decoded python value to a readable type name."""
        # fall back instead of raising KeyError for unmapped types (e.g. null)
        return cls.types.get(type(value).__name__, "UNKNOWN")

    def _get_label_properties(
        self,
        labels: List[str],
        query_template: str,
        label_key: str,
        error_message: str,
    ) -> List[Dict[str, Any]]:
        """Collect the distinct (property, type) pairs for each label.

        ``query_template`` must contain ``{graph_name}`` and ``{label}`` and
        return a single ``props`` agtype column holding a property map.
        The result entries are ``{"properties": [...], label_key: label}``.
        """
        results = []
        with self._get_cursor() as curs:
            for label in labels:
                q = query_template.format(graph_name=self.graph_name, label=label)
                data = self._execute_and_fetch(curs, q, error_message)

                # build a set of distinct properties
                distinct = set()
                for d in data:
                    for key, value in json.loads(d.props).items():
                        distinct.add((key, self._readable_type(value)))

                results.append(
                    {
                        # sorted so the schema string is deterministic
                        "properties": [
                            {"property": key, "type": kind}
                            for key, kind in sorted(distinct)
                        ],
                        label_key: label,
                    }
                )
        return results

    def _get_node_properties(self, n_labels: List[str]) -> List[Dict[str, Any]]:
        """Get the available node properties, per node label, for LLM context.

        Args:
            n_labels (List[str]): node labels to inspect

        Returns:
            List[Dict[str, Any]]: entries of the form
                {'properties': [{'property': <name>, 'type': <type>}, ...],
                 'labels': <node_label>}
        """
        # cypher query to fetch properties of a given label
        node_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH (a:`{label}`)
            RETURN properties(a) AS props
            LIMIT 100
        $$) AS (props agtype);
        """
        return self._get_label_properties(
            n_labels, node_properties_query, "labels", "Error fetching node properties"
        )

    def _get_edge_properties(self, e_labels: List[str]) -> List[Dict[str, Any]]:
        """Get the available edge properties, per edge label, for LLM context.

        Args:
            e_labels (List[str]): edge labels to inspect

        Returns:
            List[Dict[str, Any]]: entries of the form
                {'properties': [{'property': <name>, 'type': <type>}, ...],
                 'type': <edge_label>}
        """
        # cypher query to fetch properties of a given label
        edge_properties_query = """
        SELECT * FROM ag_catalog.cypher('{graph_name}', $$
            MATCH ()-[e:`{label}`]->()
            RETURN properties(e) AS props
            LIMIT 100
        $$) AS (props agtype);
        """
        return self._get_label_properties(
            e_labels, edge_properties_query, "type", "Error fetching edge properties"
        )

    def refresh_schema(self) -> None:
        """Refresh the graph schema (labels, relationships and properties).

        Updates both ``self.schema`` (a string for LLM prompts) and
        ``self.structured_schema`` (a dict).
        """
        # fetch graph schema information
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)

        node_properties = self._get_node_properties(n_labels)
        edge_properties = self._get_edge_properties(e_labels)

        # update the formatted string representation
        self.schema = f"""
        Node properties are the following:
        {node_properties}
        Relationship properties are the following:
        {edge_properties}
        The relationships are the following:
        {self._format_triples(triple_schema)}
        """

        # update the dictionary representation
        self.structured_schema = {
            "node_props": {el["labels"]: el["properties"] for el in node_properties},
            "rel_props": {el["type"]: el["properties"] for el in edge_properties},
            "relationships": triple_schema,
            "metadata": {},
        }

    @property
    def get_schema(self) -> str:
        """Return the schema of the graph as a string."""
        return self.schema

    @property
    def get_structured_schema(self) -> Dict[str, Any]:
        """Return the structured schema of the graph."""
        return self.structured_schema

    @staticmethod
    def _get_col_name(field: str, idx: int) -> str:
        """Convert a Cypher RETURN field into a pgsql column name.

        Keeps the Cypher alias if one is given; unnamed literals get a generic
        ``column_{idx}`` name.

        Args:
            field (str): a field from the Cypher RETURN clause
            idx (int): the field's position in the RETURN clause

        Returns:
            str: the column name to use in the pgsql select
        """
        # remove white space
        field = field.strip()

        # if an alias is provided for the field, use it
        if " as " in field:
            return field.split(" as ")[-1].strip()
        # unnamed primitives (numbers, strings, booleans, null) get a generic
        # name; their raw text would not be a valid SQL column name
        elif (
            field.isnumeric()
            or field in ("true", "false", "null")
            or AGEGraph._literal_regex.match(field)
        ):
            return f"column_{idx}"
        # otherwise return the value stripping out some common special chars
        else:
            return field.replace("(", "_").replace(")", "")

    @staticmethod
    def _wrap_query(query: str, graph_name: str) -> str:
        """Wrap a Cypher query in ``ag_catalog.cypher`` to make it valid AGE SQL.

        The returned fields are cast to agtype and exposed via a select.

        Args:
            query (str): a valid Cypher query
            graph_name (str): the name of the graph to query

        Returns:
            str: the equivalent pgsql query

        Raises:
            ValueError: for ``RETURN *`` or for queries containing ``$$``
        """
        # The Cypher text is embedded in a $$-quoted SQL string; a "$$" inside
        # it would end that string early and let the remainder run as raw SQL.
        if "$$" in query:
            raise ValueError("AGE graph does not support '$$' in Cypher queries")

        # pgsql template
        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
            {query}
        $$) AS ({fields});"""

        # if there are any returned fields they must be added to the pgsql query
        return_parts = AGEGraph._return_regex.split(query.lower())
        if len(return_parts) > 1:
            # isolate the last RETURN clause, dropping DISTINCT and any
            # trailing ORDER BY / SKIP / LIMIT
            clause = AGEGraph._distinct_regex.split(return_parts[-1])[-1]
            clause = AGEGraph._return_end_regex.split(clause)[0]
            fields = clause.split(",")

            # raise exception if RETURN * is found as we can't resolve the fields
            if "*" in [x.strip() for x in fields]:
                raise ValueError(
                    "AGE graph does not support 'RETURN *'"
                    + " statements in Cypher queries"
                )

            # get pgsql formatted field names
            col_names = [
                AGEGraph._get_col_name(field, idx) for idx, field in enumerate(fields)
            ]

            # build resulting pgsql relation
            fields_str = ", ".join(
                name.split(".")[-1] + " agtype" for name in col_names
            )

        # if no return statement we still need to return a single field of type agtype
        else:
            fields_str = "a agtype"

        return template.format(
            graph_name=graph_name,
            query=query,
            fields=fields_str,
            projection="*",
        )

    @staticmethod
    def _split_agtype(value: Any) -> Tuple[str, Any]:
        """Split a raw agtype value such as ``'{...}::vertex'`` into
        ``(suffix, payload)``.

        Only known agtype suffixes are split off, using the LAST ``::``, so
        values that merely contain ``::`` are left intact. Non-suffixed
        values are returned as ``("", value)``.
        """
        if isinstance(value, str):
            payload, sep, suffix = value.rpartition("::")
            if sep and suffix in AGEGraph._agtype_suffixes:
                return suffix, payload
        return "", value

    @staticmethod
    def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
        """Convert a record returned from an AGE query into a dict.

        Vertices become their property maps. Edges become
        ``(start_props, label, end_props)`` tuples, where the endpoint
        properties are filled in only if the vertex was also returned in this
        record (``{}`` otherwise), to match the neo4j implementation.

        Args:
            record (): a record from an AGE query result

        Returns:
            Dict[str, Any]: field names mapped to python values
        """
        # parse every field once; collect vertices first so edges can
        # reference vertices that appear later in the record
        vertices: Dict[Any, Any] = {}
        parsed = []
        for key in record._fields:
            dtype, payload = AGEGraph._split_agtype(getattr(record, key))
            if dtype in ("vertex", "edge"):
                payload = json.loads(payload)
                if dtype == "vertex":
                    vertices[payload["id"]] = payload.get("properties")
            parsed.append((key, dtype, payload))

        d = {}
        for key, dtype, payload in parsed:
            if dtype == "vertex":
                d[key] = payload.get("properties")
            elif dtype == "edge":
                d[key] = (
                    vertices.get(payload["start_id"], {}),
                    payload["label"],
                    vertices.get(payload["end_id"], {}),
                )
            else:
                d[key] = json.loads(payload) if isinstance(payload, str) else payload

        return d

    def query(
        self, query: str, params: Union[Dict[str, Any], None] = None
    ) -> List[Dict[str, Any]]:
        """Run a Cypher query against the graph and return the results.

        The query is wrapped into AGE-compatible SQL, executed and committed;
        on a database error the transaction is rolled back.

        Args:
            query (str): the Cypher query to execute
            params (dict): query parameters (unused in this implementation)

        Returns:
            List[Dict[str, Any]]: one dict per result row

        Raises:
            AGEQueryException: if the database rejects the query
        """
        psycopg2 = self._import_psycopg2()

        # convert cypher query to pgsql/age query
        wrapped_query = self._wrap_query(query, self.graph_name)

        # execute the query, rolling back on an error
        with self._get_cursor() as curs:
            try:
                curs.execute(wrapped_query)
                self.connection.commit()
            except psycopg2.Error as e:
                self.connection.rollback()
                raise AGEQueryException(
                    {
                        "message": "Error executing graph query: {}".format(query),
                        "details": str(e),
                    }
                ) from e

            data = curs.fetchall()

        # convert to dictionaries
        return [self._record_to_dict(d) for d in data or []]

    @staticmethod
    def _to_cypher_value(value: Any) -> str:
        """JSON-encode ``value`` as a Cypher literal.

        ``$`` is written as the ``\\u0024`` escape so values can never form the
        ``$$`` that would close the SQL dollar-quoted string.
        """
        # json.dumps only emits "$" as a literal character inside strings, so
        # this replacement cannot corrupt the surrounding syntax
        return json.dumps(value).replace("$", "\\u0024")

    @staticmethod
    def _format_properties(
        properties: Dict[str, Any], id: Union[str, None] = None
    ) -> str:
        """Convert a property dict into a Cypher map literal for CREATE/MERGE.

        Args:
            properties (Dict[str, Any]): node/edge properties
            id (Union[str, None]): node id, appended when ``properties`` has
                no ``"id"`` key

        Returns:
            str: the properties as a Cypher map string
        """
        # wrap property key in backticks to escape
        props = [
            f"`{k}`: {AGEGraph._to_cypher_value(v)}" for k, v in properties.items()
        ]

        if id is not None and "id" not in properties:
            props.append(
                f"id: {AGEGraph._to_cypher_value(id)}"
                if isinstance(id, str)
                else f"id: {id}"
            )

        return "{" + ", ".join(props) + "}"

    @staticmethod
    def clean_graph_labels(label: str) -> str:
        """Replace every run of non-alphanumeric characters in a label with '_'.

        Args:
            label (str): the original label

        Returns:
            str: the cleaned label
        """
        return re.sub(AGEGraph.label_regex, "_", label)

    def add_graph_documents(
        self, graph_documents: List[GraphDocument], include_source: bool = False
    ) -> None:
        """Insert a list of graph documents into the graph.

        Args:
            graph_documents (List[GraphDocument]): the documents to insert
            include_source (bool): if True, also MERGE a ``Document`` node for
                each source and link it to its entities with MENTIONS edges.
                A source without ``metadata["id"]`` is given the md5 of its
                page content as id.

        Returns:
            None
        """
        # query for inserting nodes
        node_insert_query = (
            """
            MERGE (n:`{label}` {properties})
            """
            if not include_source
            else """
            MERGE (n:`{label}` {properties})
            MERGE (d:Document {d_properties})
            MERGE (d)-[:MENTIONS]->(n)
            """
        )

        # query for inserting edges
        edge_insert_query = """
            MERGE (from:`{f_label}` {f_properties})
            MERGE (to:`{t_label}` {t_properties})
            MERGE (from)-[:`{r_label}` {r_properties}]->(to)
        """

        # iterate docs and insert them
        for doc in graph_documents:
            # if we are adding sources, create an id for the source
            d_properties = ""
            if include_source:
                if not doc.source.metadata.get("id"):
                    doc.source.metadata["id"] = md5(
                        doc.source.page_content.encode("utf-8")
                    ).hexdigest()
                d_properties = self._format_properties(doc.source.metadata)

            # insert entity nodes; labels are always cleaned so nodes created
            # here match the ones MERGEd by the edge query below
            for node in doc.nodes:
                # copy instead of mutating the caller's property dict
                node_props = {**node.properties, "id": node.id}
                # str.format ignores d_properties when the template lacks it
                query = node_insert_query.format(
                    label=AGEGraph.clean_graph_labels(node.type),
                    properties=self._format_properties(node_props),
                    d_properties=d_properties,
                )
                self.query(query)

            # insert relationships
            for edge in doc.relationships:
                source_props = {**edge.source.properties, "id": edge.source.id}
                target_props = {**edge.target.properties, "id": edge.target.id}
                inputs = {
                    "f_label": AGEGraph.clean_graph_labels(edge.source.type),
                    "f_properties": self._format_properties(source_props),
                    "t_label": AGEGraph.clean_graph_labels(edge.target.type),
                    "t_properties": self._format_properties(target_props),
                    "r_label": AGEGraph.clean_graph_labels(edge.type).upper(),
                    "r_properties": self._format_properties(edge.properties),
                }

                query = edge_insert_query.format(**inputs)
                self.query(query)