26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
class MultiStepQueryEngine(BaseQueryEngine):
    """Multi-step query engine.

    Operates on top of an existing query engine together with a multi-step
    query transform: on every step the original question and the reasoning
    gathered so far are turned into a new sub-query, the sub-query is run
    against the wrapped engine, and finally all sub-question/answer pairs are
    handed to the response synthesizer to answer the original question.

    Args:
        query_engine (BaseQueryEngine): Engine that answers each sub-query.
        query_transform (StepDecomposeQueryTransform): Transform that produces
            the next sub-query from the original query and prior reasoning.
        response_synthesizer (Optional[BaseSynthesizer]): Synthesizer used for
            the final answer. When omitted, one is created with
            ``get_response_synthesizer`` using the wrapped engine's callback
            manager.
        num_steps (Optional[int]): Maximum number of sub-queries to run.
            ``None`` means no step limit (only allowed with early stopping).
        early_stopping (bool): When True, stop as soon as ``stop_fn`` returns
            True for the next sub-query. When False, ``stop_fn`` is not
            consulted and the loop runs for ``num_steps`` steps.
        index_summary (str): String summary of the index, forwarded to the
            query transform as metadata.
        stop_fn (Optional[Callable[[Dict], bool]]): Stop function that takes
            an info dict (key ``"query_bundle"`` holds the next sub-query) and
            returns a bool. Defaults to ``default_stop_fn``.

    Raises:
        ValueError: If ``early_stopping`` is False and ``num_steps`` is None,
            since the loop would then have no termination condition.
    """

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: StepDecomposeQueryTransform,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        num_steps: Optional[int] = 3,
        early_stopping: bool = True,
        index_summary: str = "None",
        stop_fn: Optional[Callable[[Dict], bool]] = None,
    ) -> None:
        # Fail fast, before any collaborator is built:
        # num_steps must be provided if early_stopping is False.
        if not early_stopping and num_steps is None:
            raise ValueError("Must specify num_steps if early_stopping is False.")

        callback_manager = query_engine.callback_manager

        self._query_engine = query_engine
        self._query_transform = query_transform
        self._response_synthesizer = response_synthesizer or get_response_synthesizer(
            callback_manager=callback_manager
        )
        self._index_summary = index_summary
        self._num_steps = num_steps
        self._early_stopping = early_stopping
        # TODO: make interface to stop function better
        self._stop_fn = stop_fn or default_stop_fn

        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "response_synthesizer": self._response_synthesizer,
            "query_transform": self._query_transform,
        }

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Run the multi-step loop, then synthesize the final response."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            nodes, source_nodes, metadata = self._query_multistep(query_bundle)

            final_response = self._response_synthesizer.synthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Async variant of ``_query``; sub-queries and synthesis are awaited."""
        with self.callback_manager.event(
            CBEventType.QUERY, payload={EventPayload.QUERY_STR: query_bundle.query_str}
        ) as query_event:
            # Await the sub-queries instead of running the blocking sync loop
            # inside a coroutine.
            nodes, source_nodes, metadata = await self._aquery_multistep(query_bundle)

            final_response = await self._response_synthesizer.asynthesize(
                query=query_bundle,
                nodes=nodes,
                additional_source_nodes=source_nodes,
            )
            final_response.metadata = metadata

            query_event.on_end(payload={EventPayload.RESPONSE: final_response})

        return final_response

    def _combine_queries(
        self, query_bundle: QueryBundle, prev_reasoning: str
    ) -> QueryBundle:
        """Combine the original query with prior reasoning into a sub-query."""
        transform_metadata = {
            "prev_reasoning": prev_reasoning,
            "index_summary": self._index_summary,
        }
        return self._query_transform(query_bundle, metadata=transform_metadata)

    def _next_sub_query(
        self, query_bundle: QueryBundle, prev_reasoning: str, cur_steps: int
    ) -> Optional[QueryBundle]:
        """Return the next sub-query, or None when the loop must stop."""
        if self._num_steps is not None and cur_steps >= self._num_steps:
            return None

        updated_query_bundle = self._combine_queries(query_bundle, prev_reasoning)

        # TODO: make stop logic better
        # The stop function is only consulted when early stopping is enabled;
        # otherwise the step budget alone bounds the loop.
        stop_dict = {"query_bundle": updated_query_bundle}
        if self._early_stopping and self._stop_fn(stop_dict):
            return None
        return updated_query_bundle

    @staticmethod
    def _collect_results(
        sub_qa: List[Tuple[str, RESPONSE_TYPE]],
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Turn the (sub-query, response) pairs into the synthesizer inputs."""
        nodes = [
            NodeWithScore(
                node=TextNode(text=f"\nQuestion: {query_str}\nAnswer: {response!s}")
            )
            for query_str, response in sub_qa
        ]
        source_nodes = [
            source_node
            for _, response in sub_qa
            for source_node in response.source_nodes
        ]
        return nodes, source_nodes, {"sub_qa": sub_qa}

    def _query_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run the sub-query loop synchronously.

        Returns:
            A tuple of (one text node per sub-question/answer pair, the source
            nodes of every sub-response, metadata with the pairs under the
            ``"sub_qa"`` key).
        """
        prev_reasoning = ""
        sub_qa: List[Tuple[str, RESPONSE_TYPE]] = []

        while True:
            sub_query = self._next_sub_query(query_bundle, prev_reasoning, len(sub_qa))
            if sub_query is None:
                break

            cur_response = self._query_engine.query(sub_query)
            sub_qa.append((sub_query.query_str, cur_response))
            prev_reasoning += f"- {sub_query.query_str}\n- {cur_response!s}\n"

        return self._collect_results(sub_qa)

    async def _aquery_multistep(
        self, query_bundle: QueryBundle
    ) -> Tuple[List[NodeWithScore], List[NodeWithScore], Dict[str, Any]]:
        """Run the sub-query loop, awaiting the wrapped engine on each step.

        Same contract as ``_query_multistep``. Steps stay sequential because
        each sub-query depends on the previous answers. The query transform
        itself is still invoked synchronously.
        """
        prev_reasoning = ""
        sub_qa: List[Tuple[str, RESPONSE_TYPE]] = []

        while True:
            sub_query = self._next_sub_query(query_bundle, prev_reasoning, len(sub_qa))
            if sub_query is None:
                break

            cur_response = await self._query_engine.aquery(sub_query)
            sub_qa.append((sub_query.query_str, cur_response))
            prev_reasoning += f"- {sub_query.query_str}\n- {cur_response!s}\n"

        return self._collect_results(sub_qa)
|