"""Networkx图操作的包装器。"""
from __future__ import annotations
from typing import Any, List, NamedTuple, Optional, Tuple
# Separator between serialized triples inside a knowledge string;
# parse_triples splits its input on this token.
KG_TRIPLE_DELIMITER = "<|>"
class KnowledgeTriple(NamedTuple):
    """A (subject, predicate, object) knowledge triple stored in the graph."""

    subject: str
    predicate: str
    object_: str

    @classmethod
    def from_string(cls, triple_string: str) -> "KnowledgeTriple":
        """Create a KnowledgeTriple from its string form.

        The expected form is ``"(subject, predicate, object)"``: three fields
        separated by ``", "``. Surrounding whitespace is ignored. The
        enclosing parentheses are removed only when they are actually
        present, so an unparenthesized ``"subject, predicate, object"`` is
        parsed intact instead of losing its first and last characters.

        Args:
            triple_string: Serialized triple.

        Returns:
            The parsed triple.

        Raises:
            ValueError: If the string does not split into exactly three
                ``", "``-separated fields.
        """
        text = triple_string.strip()
        # Drop the wrapping parentheses only if they exist; slicing blindly
        # would eat real characters of the subject/object.
        if text.startswith("("):
            text = text[1:]
        if text.endswith(")"):
            text = text[:-1]
        # Tuple unpacking raises ValueError on a wrong field count, which
        # callers use to detect malformed triples.
        subject, predicate, object_ = text.split(", ")
        return cls(subject, predicate, object_)
def parse_triples(knowledge_str: str) -> List[KnowledgeTriple]:
    """Parse knowledge triples out of a knowledge string.

    The string is stripped, then split on ``KG_TRIPLE_DELIMITER``; each piece
    is handed to ``KnowledgeTriple.from_string``. Pieces that raise
    ``ValueError`` there are skipped.

    Args:
        knowledge_str: Delimiter-separated serialized triples, or ``"NONE"``.

    Returns:
        The successfully parsed triples, in input order. An empty list when
        the stripped input is empty or exactly ``"NONE"``.
    """
    text = knowledge_str.strip()
    if text in ("", "NONE"):
        return []

    parsed = []
    for chunk in text.split(KG_TRIPLE_DELIMITER):
        try:
            triple = KnowledgeTriple.from_string(chunk)
        except ValueError:
            # Malformed piece: ignore it and keep going.
            pass
        else:
            parsed.append(triple)
    return parsed
def get_entities(entity_str: str) -> List[str]:
    """Extract entity names from a comma-separated entity string.

    Each name is stripped of surrounding whitespace. Blank entries (from an
    empty input, doubled commas, or a trailing comma) are dropped so callers
    never receive an empty entity name.

    Args:
        entity_str: Comma-separated entity names, or ``"NONE"``.

    Returns:
        The entity names in input order; an empty list when the stripped
        input is ``"NONE"`` or contains no non-blank names.
    """
    if entity_str.strip() == "NONE":
        return []
    stripped = (part.strip() for part in entity_str.split(","))
    return [name for name in stripped if name]
class NetworkxEntityGraph:
    """Networkx wrapper for entity graph operations.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the
        calling code may attempt commands that would result in deletion,
        mutation of data if appropriately prompted or reading sensitive data
        if such data is present in the database.
        The best way to guard against such negative outcomes is to (as
        appropriate) limit the permissions granted to the credentials used
        with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    @staticmethod
    def _import_networkx() -> Any:
        """Import and return the networkx module, with an install hint on failure."""
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError(
                "Could not import networkx python package. "
                "Please install it with `pip install networkx`."
            ) from e
        return nx

    def __init__(self, graph: Optional[Any] = None) -> None:
        """Create a new graph, optionally wrapping an existing one.

        Args:
            graph: An existing ``networkx.DiGraph`` to wrap. When ``None``,
                an empty directed graph is created.

        Raises:
            ImportError: If networkx is not installed.
            ValueError: If ``graph`` is given but is not a ``networkx.DiGraph``.
        """
        nx = self._import_networkx()
        if graph is not None:
            if not isinstance(graph, nx.DiGraph):
                raise ValueError("Passed in graph is not of correct shape")
            self._graph = graph
        else:
            self._graph = nx.DiGraph()

    @classmethod
    def from_gml(cls, gml_path: str) -> NetworkxEntityGraph:
        """Build an entity graph from a GML file.

        Args:
            gml_path: Path passed to ``networkx.read_gml``.

        Raises:
            ImportError: If networkx is not installed.
            ValueError: If the loaded graph is not a ``networkx.DiGraph``
                (raised by the constructor).
        """
        nx = cls._import_networkx()
        graph = nx.read_gml(gml_path)
        return cls(graph)

    def add_triple(self, knowledge_triple: KnowledgeTriple) -> None:
        """Add a triple to the graph.

        The edge runs subject -> object and stores the predicate in the
        ``relation`` edge attribute. An existing edge between the same two
        nodes has its ``relation`` overwritten.
        """
        # add_edge creates missing endpoint nodes itself, so no explicit
        # has_node/add_node dance is needed.
        self._graph.add_edge(
            knowledge_triple.subject,
            knowledge_triple.object_,
            relation=knowledge_triple.predicate,
        )

    def delete_triple(self, knowledge_triple: KnowledgeTriple) -> None:
        """Delete a triple's edge from the graph; no-op if the edge is absent.

        Only subject and object are used to locate the edge; the predicate
        is not compared.
        """
        if self._graph.has_edge(knowledge_triple.subject, knowledge_triple.object_):
            self._graph.remove_edge(knowledge_triple.subject, knowledge_triple.object_)

    def get_triples(self) -> List[Tuple[str, str, str]]:
        """Get all triples in the graph.

        Returns:
            One tuple per edge, ordered ``(source, target, relation)``. Note
            that the relation comes last, unlike ``KnowledgeTriple``.
        """
        return [(u, v, d["relation"]) for u, v, d in self._graph.edges(data=True)]

    def get_entity_knowledge(self, entity: str, depth: int = 1) -> List[str]:
        """Get information about an entity.

        Args:
            entity: Node to start from.
            depth: Depth limit for the depth-first edge traversal.

        Returns:
            One ``"<source> <relation> <target>"`` string per traversed edge;
            an empty list if ``entity`` is not a node of the graph.
        """
        import networkx as nx

        # TODO: Have more information-specific retrieval methods
        if not self._graph.has_node(entity):
            return []

        results = []
        for src, sink in nx.dfs_edges(self._graph, entity, depth_limit=depth):
            relation = self._graph[src][sink]["relation"]
            results.append(f"{src} {relation} {sink}")
        return results

    def write_to_gml(self, path: str) -> None:
        """Write the graph to ``path`` in GML format."""
        import networkx as nx

        nx.write_gml(self._graph, path)

    def clear(self) -> None:
        """Clear the graph (nodes and edges)."""
        self._graph.clear()

    def clear_edges(self) -> None:
        """Clear the graph's edges."""
        self._graph.clear_edges()

    def add_node(self, node: str) -> None:
        """Add a node to the graph."""
        self._graph.add_node(node)

    def remove_node(self, node: str) -> None:
        """Remove a node from the graph; no-op if the node is absent."""
        if self._graph.has_node(node):
            self._graph.remove_node(node)

    def has_node(self, node: str) -> bool:
        """Return True if the graph contains the given node."""
        return self._graph.has_node(node)

    def remove_edge(self, source_node: str, destination_node: str) -> None:
        """Remove the edge from ``source_node`` to ``destination_node``.

        Unlike ``delete_triple`` and ``remove_node``, this does not check for
        existence first; the underlying graph's error for a missing edge
        propagates to the caller.
        """
        self._graph.remove_edge(source_node, destination_node)

    def has_edge(self, source_node: str, destination_node: str) -> bool:
        """Return True if the graph has an edge between the given nodes."""
        graph = self._graph
        return (
            graph.has_node(source_node)
            and graph.has_node(destination_node)
            and graph.has_edge(source_node, destination_node)
        )

    def get_neighbors(self, node: str) -> List[str]:
        """Return the neighbor nodes of the given node as a list."""
        # The underlying call yields an iterator; materialize it so the
        # result matches the declared List[str] return type.
        return list(self._graph.neighbors(node))

    def get_number_of_nodes(self) -> int:
        """Get the number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def get_topological_sort(self) -> List[str]:
        """Return the entity names in topological (causal dependency) order."""
        import networkx as nx

        return list(nx.topological_sort(self._graph))

    def draw_graphviz(self, **kwargs: Any) -> None:
        """Render the graph to a file using pygraphviz.

        Args:
            **kwargs: ``prog`` selects the layout program (default
                ``"dot"``); ``path`` is the output file (default
                ``"graph.svg"``). Other keys are ignored.

        Raises:
            ImportError: If pygraphviz, or the graphviz system library it
                needs, cannot be imported.

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> graph.draw_graphviz(prog="dot", path="web.svg")
            >>> SVG('web.svg')
        """
        from networkx.drawing.nx_agraph import to_agraph

        try:
            import pygraphviz  # noqa: F401

        except ImportError as e:
            if e.name == "_graphviz":
                # pygraphviz is installed but its native extension failed to
                # load, e.g.:
                #   ImportError: libcgraph.so.6: cannot open shared object file
                raise ImportError(
                    "Could not import graphviz debian package. "
                    "Please install it with: "
                    "`sudo apt-get update` "
                    "`sudo apt-get install graphviz graphviz-dev`"
                ) from e
            else:
                raise ImportError(
                    "Could not import pygraphviz python package. "
                    "Please install it with: "
                    "`pip install pygraphviz`."
                ) from e

        graph = to_agraph(self._graph)  # --> pygraphviz.agraph.AGraph
        # pygraphviz.github.io/documentation/stable/tutorial.html#layout-and-drawing
        graph.layout(prog=kwargs.get("prog", "dot"))
        graph.draw(kwargs.get("path", "graph.svg"))