33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
class LATSAgentWorker(CustomSimpleAgentWorker):
    """Agent worker that performs a step of Language Agent Tree Search (LATS).

    Source paper: https://arxiv.org/pdf/2310.04406v2.pdf.

    Each call to ``_run_step`` selects the best leaf of the search tree,
    expands it into candidate next steps, runs each candidate as a ReAct
    trajectory (possibly invoking a tool), scores the results with the LLM,
    and backpropagates the scores. Iteration continues until a solution is
    found or ``max_rollouts`` is exhausted.
    """

    num_expansions: int = Field(default=2, description="Number of expansions to do.")
    reflection_prompt: PromptTemplate = Field(..., description="Reflection prompt.")
    # NOTE: "candiate" is a typo, but the field name is preserved for
    # backward compatibility with existing callers/configs.
    candiate_expansion_prompt: PromptTemplate = Field(
        ..., description="Candidate expansion prompt."
    )
    max_rollouts: int = Field(
        default=5,
        description=(
            "Max rollouts (default 5). If set to -1, we keep going until the "
            "first solution is found."
        ),
    )
    chat_formatter: ReActChatFormatter = Field(
        default_factory=ReActChatFormatter, description="Chat formatter."
    )

    def __init__(
        self,
        tools: List[BaseTool],
        llm: Optional[LLM] = None,
        num_expansions: int = 2,
        max_rollouts: int = 5,
        reflection_prompt: Optional[PromptTemplate] = None,
        candiate_expansion_prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> None:
        """Init params, falling back to global ``Settings.llm`` and the
        default reflection/candidate prompts when not provided."""
        llm = llm or Settings.llm
        super().__init__(
            tools=tools,
            llm=llm,
            num_expansions=num_expansions,
            max_rollouts=max_rollouts,
            reflection_prompt=reflection_prompt or DEFAULT_REFLECTION_PROMPT,
            candiate_expansion_prompt=candiate_expansion_prompt
            or DEFAULT_CANDIDATES_PROMPT,
            **kwargs,
        )

    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize per-task state: rollout counter, solution queue, root node.

        The root node's reasoning is seeded with the task input as an
        observation; its evaluation is a placeholder (score=1).
        """
        root_node = SearchNode(
            current_reasoning=[ObservationReasoningStep(observation=task.input)],
            evaluation=Evaluation(score=1),  # evaluation for root node is blank
        )
        return {"count": 0, "solution_queue": [], "root_node": root_node}

    async def _arun_candidate(
        self,
        node: SearchNode,
        task: Task,
    ) -> List[BaseReasoningStep]:
        """Generate one candidate trajectory for the given node.

        Samples the action space by generating a ReAct step with the LLM.
        If the step selects an action, the corresponding tool is run and its
        output appended as an observation step; if the step is terminal, the
        response step alone is returned.
        """
        output_parser = ReActOutputParser()
        # format react prompt from chat history + the node's reasoning so far
        formatted_prompt = self.chat_formatter.format(
            self.tools,
            chat_history=task.memory.get(),
            current_reasoning=node.current_reasoning,
        )
        # run LLM
        response = await self.llm.achat(formatted_prompt)
        # parse output into reasoning step; on parse failure, degrade to a
        # terminal response step carrying the error (best-effort, no raise)
        try:
            reasoning_step = output_parser.parse(response.message.content)
        except ValueError as e:
            reasoning_step = ResponseReasoningStep(
                thought=response.message.content,
                response=f"Encountered an error parsing: {e!s}",
            )
        # get response or run tool
        if reasoning_step.is_done:
            reasoning_step = cast(ResponseReasoningStep, reasoning_step)
            current_reasoning = [reasoning_step]
        else:
            reasoning_step = cast(ActionReasoningStep, reasoning_step)
            tool_selection = ToolSelection(
                tool_id=reasoning_step.action,
                tool_name=reasoning_step.action,
                tool_kwargs=reasoning_step.action_input,
            )
            # tool errors are captured as observations rather than raised,
            # so a failing candidate still produces a scorable trajectory
            try:
                tool_output = await acall_tool_with_selection(
                    tool_selection, self.tools, verbose=self.verbose
                )
            except Exception as e:
                tool_output = f"Encountered error: {e!s}"
            observation_step = ObservationReasoningStep(observation=str(tool_output))
            current_reasoning = [reasoning_step, observation_step]
        return current_reasoning

    async def _aevaluate(
        self,
        cur_node: SearchNode,
        current_reasoning: List[BaseReasoningStep],
        input: str,
    ) -> Evaluation:
        """Score a candidate trajectory with the reflection prompt.

        Returns the structured ``Evaluation`` object (callers read ``.score``).
        Note: annotation fixed from ``float`` — the method returns the full
        ``Evaluation``, not a bare score.
        """
        all_reasoning = cur_node.current_reasoning + current_reasoning
        history_str = "\n".join([s.get_content() for s in all_reasoning])
        evaluation = await self.llm.astructured_predict(
            Evaluation,
            prompt=self.reflection_prompt,
            query=input,
            conversation_history=history_str,
        )
        if self.verbose:
            print_text(
                f"> Evaluation for input {input}\n: {evaluation}\n\n", color="pink"
            )
        return evaluation

    async def _get_next_candidates(
        self,
        cur_node: SearchNode,
        input: str,
    ) -> List[str]:
        """Ask the LLM for the next candidate expansions of a node.

        Always returns exactly ``num_expansions`` strings: the list is
        truncated if too long and padded by repetition if too short.
        """
        # get candidates
        history_str = "\n".join([s.get_content() for s in cur_node.current_reasoning])
        candidates = await self.llm.astructured_predict(
            Candidates,
            prompt=self.candiate_expansion_prompt,
            query=input,
            conversation_history=history_str,
            num_candidates=self.num_expansions,
        )
        candidate_strs = candidates.candidates[: self.num_expansions]
        if self.verbose:
            print_text(f"> Got candidates: {candidate_strs}\n", color="yellow")
        # ensure we have the right number of candidates
        if len(candidate_strs) < self.num_expansions:
            return (candidate_strs * self.num_expansions)[: self.num_expansions]
        else:
            return candidate_strs[: self.num_expansions]

    def _update_state(
        self,
        node: SearchNode,
        current_reasoning: List[BaseReasoningStep],
        evaluation: Evaluation,
    ) -> SearchNode:
        """Attach a child node to ``node`` and backpropagate its reward.

        The child's reasoning is the parent's reasoning plus the new steps.
        """
        # create child node
        new_node = SearchNode(
            current_reasoning=node.current_reasoning + current_reasoning,
            parent=node,
            children=[],
            evaluation=evaluation,
        )
        node.children.append(new_node)
        # backpropagate the reward
        new_node.backpropagate(evaluation.score)
        return new_node

    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step (sync wrapper over ``_arun_step``).

        Returns:
            Tuple of (agent_response, is_done)

        NOTE(review): ``asyncio.run`` raises RuntimeError if called from a
        running event loop — use ``_arun_step`` directly in async contexts.
        """
        return asyncio.run(self._arun_step(state, task, input))

    async def _arun_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run one LATS rollout: select, expand, evaluate, backpropagate.

        Returns:
            Tuple of (agent_response, is_done)
        """
        root_node = state["root_node"]
        # selection: pick the most promising leaf to expand
        cur_node = root_node.get_best_leaf()
        if self.verbose:
            print_text(
                f"> Selecting node to expand: {cur_node.answer}\n", color="green"
            )
        # expansion: generate num_expansions candidate child nodes, each
        # seeded with a placeholder evaluation until scored below
        new_candidates = await self._get_next_candidates(
            cur_node,
            task.input,
        )
        new_nodes = [
            self._update_state(
                cur_node,
                [ObservationReasoningStep(observation=candidate)],
                Evaluation(score=1),  # evaluation for candidate node is blank
            )
            for candidate in new_candidates
        ]
        solution_queue: List[SearchNode] = state["solution_queue"]
        # simulation: run each candidate's ReAct trajectory concurrently
        candidate_jobs = [
            self._arun_candidate(new_node, task) for new_node in new_nodes
        ]
        all_new_reasoning_steps = await asyncio.gather(*candidate_jobs)
        if self.verbose:
            for new_reasoning_steps in all_new_reasoning_steps:
                out_txt = "\n".join([s.get_content() for s in new_reasoning_steps])
                print_text(f"> Generated new reasoning step: {out_txt}\n", color="blue")
        # evaluation: score each trajectory concurrently
        eval_jobs = [
            self._aevaluate(new_node, new_reasoning_steps, task.input)
            for new_node, new_reasoning_steps in zip(new_nodes, all_new_reasoning_steps)
        ]
        evaluations = await asyncio.gather(*eval_jobs)
        # backpropagation: attach scored children, collect any solutions
        for new_reasoning_steps, cur_new_node, evaluation in zip(
            all_new_reasoning_steps, new_nodes, evaluations
        ):
            new_node = self._update_state(cur_new_node, new_reasoning_steps, evaluation)
            if new_node.is_done:
                if self.verbose:
                    print_text(
                        f"> Found solution node: {new_node.answer}\n", color="cyan"
                    )
                solution_queue.append(new_node)
        # check if done: -1 means "stop at first solution"; otherwise stop
        # once the rollout budget is spent
        state["count"] += 1
        if self.max_rollouts == -1 and solution_queue:
            is_done = True
        elif self.max_rollouts > 0 and state["count"] >= self.max_rollouts:
            is_done = True
        else:
            is_done = False
        # determine response: best-scoring solution so far, if any
        if solution_queue:
            best_solution_node = max(solution_queue, key=lambda x: x.score)
            response = best_solution_node.answer
        else:
            response = "I am still thinking."
        if self.verbose:
            print_text(f"> Got final response: {response!s}\n", color="green")
        # return response
        return AgentChatResponse(response=str(response)), is_done

    def _finalize_task(self, state: Dict[str, Any], **kwargs) -> None:
        """Finalize task. Nothing to clean up for this worker."""