36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
class KnowledgeGraphIndex(BaseIndex[KG]):
    """Knowledge Graph Index.

    Builds a knowledge graph by extracting triplets from nodes, and
    leverages the knowledge graph during query time.

    Args:
        kg_triple_extract_template (BasePromptTemplate): Prompt template
            used to extract triplets.
        max_triplets_per_chunk (int): Maximum number of triplets to extract
            per chunk.
        service_context (Optional[ServiceContext]): Service context to use
            (deprecated).
        storage_context (Optional[StorageContext]): Storage context to use.
        graph_store (Optional[GraphStore]): Graph store to use.
        show_progress (bool): Whether to show tqdm progress bars.
            Defaults to False.
        include_embeddings (bool): Whether to include embeddings in the
            index. Defaults to False.
        max_object_length (int): Maximum length of an object in a triplet.
            Defaults to 128.
        kg_triplet_extract_fn (Optional[Callable]): Function to extract
            triplets. Defaults to None (use the LLM-based extractor).
    """

    # Index-struct type built and persisted by this index.
    index_struct_cls = KG
def __init__(
    self,
    nodes: Optional[Sequence[BaseNode]] = None,
    objects: Optional[Sequence[IndexNode]] = None,
    index_struct: Optional[KG] = None,
    llm: Optional[LLM] = None,
    embed_model: Optional[BaseEmbedding] = None,
    storage_context: Optional[StorageContext] = None,
    kg_triple_extract_template: Optional[BasePromptTemplate] = None,
    max_triplets_per_chunk: int = 10,
    include_embeddings: bool = False,
    show_progress: bool = False,
    max_object_length: int = 128,
    kg_triplet_extract_fn: Optional[Callable] = None,
    # deprecated
    service_context: Optional[ServiceContext] = None,
    **kwargs: Any,
) -> None:
    """Initialize params.

    See the class docstring for parameter descriptions.
    """
    # need to set parameters before building index in base class.
    self.include_embeddings = include_embeddings
    self.max_triplets_per_chunk = max_triplets_per_chunk
    self.kg_triple_extract_template = (
        kg_triple_extract_template or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
    )
    # NOTE: Partially format keyword extract template here.
    self.kg_triple_extract_template = (
        self.kg_triple_extract_template.partial_format(
            max_knowledge_triplets=self.max_triplets_per_chunk
        )
    )
    self._max_object_length = max_object_length
    self._kg_triplet_extract_fn = kg_triplet_extract_fn
    # Fall back to global Settings (or the deprecated service_context)
    # when no explicit LLM / embed model is supplied.
    self._llm = llm or llm_from_settings_or_context(Settings, service_context)
    self._embed_model = embed_model or embed_model_from_settings_or_context(
        Settings, service_context
    )
    # Base-class init may build the index from `nodes`, which is why the
    # attributes above had to be set first.
    super().__init__(
        nodes=nodes,
        index_struct=index_struct,
        service_context=service_context,
        storage_context=storage_context,
        show_progress=show_progress,
        objects=objects,
        **kwargs,
    )
    # TODO: legacy conversion - remove in next release
    if (
        len(self.index_struct.table) > 0
        and isinstance(self.graph_store, SimpleGraphStore)
        and len(self.graph_store._data.graph_dict) == 0
    ):
        logger.warning("Upgrading previously saved KG index to new storage format.")
        self.graph_store._data.graph_dict = self.index_struct.rel_map
@property
def graph_store(self) -> GraphStore:
    """The underlying graph store holding the upserted triplets."""
    return self._graph_store
def as_retriever(
    self,
    retriever_mode: Optional[str] = None,
    embed_model: Optional[BaseEmbedding] = None,
    **kwargs: Any,
) -> BaseRetriever:
    """Return a KG table retriever over this index.

    If no mode is given and the index stores triplet embeddings, hybrid
    (keyword + embedding) retrieval is selected automatically.
    """
    # Local import avoids a circular dependency with the retrievers module.
    from llama_index.core.indices.knowledge_graph.retrievers import (
        KGRetrieverMode,
        KGTableRetriever,
    )

    # Default to hybrid retrieval when embeddings are available.
    if retriever_mode is None and len(self.index_struct.embedding_dict) > 0:
        retriever_mode = KGRetrieverMode.HYBRID

    return KGTableRetriever(
        self,
        object_map=self._object_map,
        llm=self._llm,
        embed_model=embed_model or self._embed_model,
        retriever_mode=retriever_mode,
        **kwargs,
    )
def _extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
    """Extract triplets from text via the user-supplied fn, else via the LLM."""
    custom_extract = self._kg_triplet_extract_fn
    if custom_extract is None:
        return self._llm_extract_triplets(text)
    return custom_extract(text)
def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]:
    """Extract knowledge triplets from text using the LLM prompt."""
    response = self._llm.predict(
        self.kg_triple_extract_template,
        text=text,
    )
    # Objects longer than _max_object_length bytes are discarded downstream.
    return self._parse_triplet_response(
        response, max_length=self._max_object_length
    )
@staticmethod
def _parse_triplet_response(
    response: str, max_length: int = 128
) -> List[Tuple[str, str, str]]:
    """Parse an LLM response into (subject, predicate, object) triplets.

    Lines without a well-formed ``(a, b, c)`` span are skipped, as are
    triplets with empty parts or parts whose UTF-8 encoding exceeds
    ``max_length`` bytes. Parts are stripped of double quotes and
    capitalized for disambiguation.
    """
    triplets: List[Tuple[str, str, str]] = []
    for line in response.strip().split("\n"):
        open_idx = line.find("(")
        close_idx = line.find(")")
        # Skip empty lines and lines lacking an ordered "(...)" pair.
        if open_idx == -1 or close_idx == -1 or close_idx < open_idx:
            continue
        parts = line[open_idx + 1 : close_idx].split(",")
        if len(parts) != 3:
            continue
        # Byte length (not character count) guards against run-away
        # tokens from poorly formatted extractions; serious KG building
        # would need NLP models for better triplet extraction.
        if any(len(part.encode("utf-8")) > max_length for part in parts):
            continue
        stripped = [part.strip() for part in parts]
        if not all(stripped):
            # Skip partial triplets.
            continue
        # Strip double quotes and capitalize for disambiguation.
        subj, pred, obj = (s.strip('"').capitalize() for s in stripped)
        triplets.append((subj, pred, obj))
    return triplets
def _build_index_from_nodes(self, nodes: Sequence[BaseNode]) -> KG:
    """Build a fresh KG index struct by extracting triplets from each node."""
    index_struct = self.index_struct_cls()
    nodes_with_progress = get_tqdm_iterable(
        nodes, self._show_progress, "Processing nodes"
    )
    for node in nodes_with_progress:
        triplets = self._extract_triplets(
            node.get_content(metadata_mode=MetadataMode.LLM)
        )
        logger.debug(f"> Extracted triplets: {triplets}")

        # Register each triplet in the graph store and map its subject
        # and object keywords back to the source node.
        for triplet in triplets:
            subj, _, obj = triplet
            self.upsert_triplet(triplet)
            index_struct.add_node([subj, obj], node)

        if self.include_embeddings:
            # Batch-embed the stringified triplets for hybrid retrieval.
            triplet_texts = [str(t) for t in triplets]
            embeddings = self._embed_model.get_text_embedding_batch(
                triplet_texts, show_progress=self._show_progress
            )
            for triplet_text, embedding in zip(triplet_texts, embeddings):
                index_struct.add_to_embedding_dict(triplet_text, embedding)
    return index_struct
def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None:
    """Insert nodes, extracting and registering their triplets."""
    for node in nodes:
        triplets = self._extract_triplets(
            node.get_content(metadata_mode=MetadataMode.LLM)
        )
        logger.debug(f"Extracted triplets: {triplets}")
        for triplet in triplets:
            subj, _, obj = triplet
            triplet_str = str(triplet)
            self.upsert_triplet(triplet)
            self._index_struct.add_node([subj, obj], node)
            if not self.include_embeddings:
                continue
            if triplet_str in self._index_struct.embedding_dict:
                # Already embedded by a previous insert.
                continue
            embedding = self._embed_model.get_text_embedding(triplet_str)
            self._index_struct.add_to_embedding_dict(triplet_str, embedding)

    # Update the storage context's index_store
    self._storage_context.index_store.add_index_struct(self._index_struct)
def upsert_triplet(
    self, triplet: Tuple[str, str, str], include_embeddings: bool = False
) -> None:
    """Insert a knowledge triplet, with an optional embedding.

    Used for manual insertion of KG triplets (in the form of
    (subject, relationship, object)).

    Args:
        triplet (tuple): Knowledge triplet to upsert.
        include_embeddings (bool): Whether to also embed the triplet and
            persist it in the index struct. Defaults to False.
    """
    self._graph_store.upsert_triplet(*triplet)
    if include_embeddings:
        triplet_str = str(triplet)
        embedding = self._embed_model.get_text_embedding(triplet_str)
        self._index_struct.add_to_embedding_dict(triplet_str, embedding)
        # Persist the updated index struct.
        self._storage_context.index_store.add_index_struct(self._index_struct)
def add_node(self, keywords: List[str], node: BaseNode) -> None:
    """Add a node, indexed by the given keywords.

    Used for manual insertion of nodes (keyed by keywords).

    Args:
        keywords (List[str]): Keywords to index the node under.
        node (Node): Node to be indexed.
    """
    self._index_struct.add_node(keywords, node)
    # allow_update makes repeated insertion of the same node idempotent.
    self._docstore.add_documents([node], allow_update=True)
def upsert_triplet_and_node(
    self,
    triplet: Tuple[str, str, str],
    node: BaseNode,
    include_embeddings: bool = False,
) -> None:
    """Upsert a KG triplet and index the associated node.

    Calls both ``upsert_triplet`` and ``add_node``. Behavior is
    idempotent; if the node already exists, only the triplet (and its
    optional embedding) is added.

    Args:
        triplet (tuple): Knowledge triplet of the form
            (subject, relationship, object).
        node (Node): Node to be indexed.
        include_embeddings (bool): Whether to also embed the triplet.
            Defaults to False.
    """
    subj, _, obj = triplet
    # Delegate embedding computation and index-store persistence to
    # upsert_triplet rather than duplicating that logic here.
    self.upsert_triplet(triplet, include_embeddings=include_embeddings)
    self.add_node([subj, obj], node)
def _delete_node(self, node_id: str, **delete_kwargs: Any) -> None:
    """Delete a node (not supported for the KG index)."""
    raise NotImplementedError("Delete is not supported for KG index yet.")
@property
def ref_doc_info(self) -> Dict[str, RefDocInfo]:
    """Retrieve a dict mapping of ingested documents and their nodes+metadata."""
    # Flatten every doc-id set in the keyword table into one id list.
    doc_id_sets = list(self._index_struct.table.values())
    all_doc_ids = list(set().union(*doc_id_sets))

    all_ref_doc_info: Dict[str, RefDocInfo] = {}
    for node in self.docstore.get_nodes(all_doc_ids):
        ref_node = node.source_node
        if ref_node:
            info = self.docstore.get_ref_doc_info(ref_node.node_id)
            if info:
                all_ref_doc_info[ref_node.node_id] = info
    return all_ref_doc_info
def get_networkx_graph(self, limit: int = 100) -> Any:
    """Get a networkx representation of the graph structure.

    Args:
        limit (int): Number of starting subjects to include in the graph.

    NOTE: This function requires networkx to be installed.
    NOTE: This is a beta feature.
    """
    try:
        import networkx as nx
    except ImportError:
        raise ImportError(
            "Please install networkx to visualize the graph: `pip install networkx`"
        )
    g = nx.Graph()
    subjs = list(self.index_struct.table.keys())
    # add edges
    rel_map = self._graph_store.get_rel_map(subjs=subjs, depth=1, limit=limit)
    added_nodes = set()
    for keyword in rel_map:
        for path in rel_map[keyword]:
            subj = keyword
            # Paths appear flattened as [subj, rel, obj, rel, obj, ...];
            # walk two steps at a time to recover each (subj, rel, obj).
            for i in range(0, len(path), 2):
                if i + 2 >= len(path):
                    # End of path (or a truncated, incomplete hop).
                    break
                if subj not in added_nodes:
                    g.add_node(subj)
                    added_nodes.add(subj)
                rel = path[i + 1]
                obj = path[i + 2]
                # Use the relation as both edge label and hover title.
                g.add_edge(subj, obj, label=rel, title=rel)
                # The object becomes the subject of the next hop.
                subj = obj
    return g
@property
def query_context(self) -> Dict[str, Any]:
    """Context injected at query time: exposes the graph store."""
    return {GRAPH_STORE_KEY: self._graph_store}
|