12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class MarkdownNodeParser(NodeParser):
    """Markdown node parser.

    Splits a document into nodes using custom Markdown splitting logic: a new
    section starts at every line matching ``^(#+)\\s`` (a header) that is not
    inside a ```` ``` ````-fenced code block, and each non-blank section
    becomes one node.

    Args:
        include_metadata (bool): whether to include metadata in nodes
        include_prev_next_rel (bool): whether to include prev/next relationships

    """

    @classmethod
    def from_defaults(
        cls,
        include_metadata: bool = True,
        include_prev_next_rel: bool = True,
        callback_manager: Optional[CallbackManager] = None,
    ) -> "MarkdownNodeParser":
        """Build a parser with the given settings.

        A fresh, empty ``CallbackManager`` is used when none is supplied.
        """
        callback_manager = callback_manager or CallbackManager([])
        return cls(
            include_metadata=include_metadata,
            include_prev_next_rel=include_prev_next_rel,
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "MarkdownNodeParser"

    def _parse_nodes(
        self,
        nodes: Sequence[BaseNode],
        show_progress: bool = False,
        **kwargs: Any,
    ) -> List[BaseNode]:
        """Split every input node and concatenate the results, preserving order."""
        all_nodes: List[BaseNode] = []
        nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes")

        for node in nodes_with_progress:
            # Distinct name so the ``nodes`` parameter is not shadowed.
            split_nodes = self.get_nodes_from_node(node)
            all_nodes.extend(split_nodes)

        return all_nodes

    def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]:
        """Get nodes from a document.

        Returns one node per header-delimited section. A header line is kept
        as the first line of its section's text, without the leading ``#``
        markers. The enclosing header hierarchy is attached as
        ``Header_<level>`` metadata.

        Sections that contain only whitespace (e.g. blank lines before the
        first header) are skipped rather than emitted as empty nodes. An empty
        or whitespace-only document therefore yields an empty list.
        """
        text = node.get_content(metadata_mode=MetadataMode.NONE)
        markdown_nodes = []
        lines = text.split("\n")
        metadata: Dict[str, str] = {}
        code_block = False
        current_section = ""

        for line in lines:
            # A fence line opens or closes a code block. Lines starting with
            # '#' inside a code block (e.g. shell comments) are not headers.
            if line.lstrip().startswith("```"):
                code_block = not code_block
            header_match = re.match(r"^(#+)\s(.*)", line)
            if header_match and not code_block:
                # Flush the previous section under the *previous* header's
                # metadata, unless it is whitespace-only.
                if current_section.strip():
                    markdown_nodes.append(
                        self._build_node_from_split(
                            current_section.strip(), node, metadata
                        )
                    )
                # `#+` never captures whitespace, so its length is the level.
                metadata = self._update_metadata(
                    metadata, header_match.group(2), len(header_match.group(1))
                )
                current_section = f"{header_match.group(2)}\n"
            else:
                current_section += line + "\n"

        # Flush the trailing section, again skipping whitespace-only content.
        if current_section.strip():
            markdown_nodes.append(
                self._build_node_from_split(current_section.strip(), node, metadata)
            )

        return markdown_nodes

    def _update_metadata(
        self, headers_metadata: dict, new_header: str, new_header_level: int
    ) -> dict:
        """Update the markdown headers for metadata.

        Keeps the ``Header_<i>`` entries for levels shallower than
        ``new_header_level`` (the new header's ancestors). Drops entries at the
        same or deeper levels. Records ``new_header`` under
        ``Header_<new_header_level>``. Returns a new dict; the input dict is
        not modified.
        """
        updated_headers = {}

        for i in range(1, new_header_level):
            key = f"Header_{i}"
            if key in headers_metadata:
                updated_headers[key] = headers_metadata[key]

        updated_headers[f"Header_{new_header_level}"] = new_header
        return updated_headers

    def _build_node_from_split(
        self,
        text_split: str,
        node: BaseNode,
        metadata: dict,
    ) -> TextNode:
        """Build a node from a single text split of ``node``.

        When ``include_metadata`` is set, ``metadata`` (the header hierarchy)
        is merged over the built node's own metadata; on key clashes the
        header entries win.
        """
        # Distinct name so the source ``node`` parameter is not rebound.
        new_node = build_nodes_from_splits(
            [text_split], node, id_func=self.id_func
        )[0]

        if self.include_metadata:
            new_node.metadata = {**new_node.metadata, **metadata}
        return new_node