Skip to main content

AgentOptimizer:一种培训LLM Agent的主动方式

在Colab中打开 在GitHub上打开

AutoGen提供了由LLM驱动的可对话代理,可以通过自动聊天来执行任务,无论是工具还是人类。该框架通过多代理对话允许工具使用和人类参与。有关此功能的文档,请参阅此处

在传统的机器学习流程中,我们通过更新模型在训练集上的损失来训练模型的参数。而在LLM代理的时代,我们应该如何训练一个代理呢?在这里,我们迈出了训练代理的初始步骤。受到OpenAI提供的函数调用能力的启发,我们将模型参数与代理函数/技能进行类比,并根据代理在训练集上的历史表现来更新代理的函数/技能。作为一种主动训练代理的方式,我们的方法可以提高代理的能力,而无需访问LLM的参数。

在这个笔记本中,我们介绍了一个新的类,'AgentOptimizer',它能根据历史对话记录改进一个Assistant-UserProxy对的函数列表。这个功能将支持代理改进解决与之前任务相同类型的问题的能力。具体来说,给定一组训练数据,AgentOptimizer会迭代地提示LLM优化AssistantAgent和UserProxyAgent的现有函数列表,必要时进行代码实现。它还包括两种策略,回滚和提前停止,以简化训练过程。在示例场景中,我们测试了提出的AgentOptimizer在解决MATH数据集中的问题。

AgentOptimizer

更多信息可以在论文中找到。

作者:- Shaokun Zhang,宾夕法尼亚州立大学博士生

import copy
import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from openai import BadRequestError

import autogen
from autogen import config_list_from_json
from autogen.agentchat import Agent
from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer
from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent
from autogen.code_utils import extract_code
from autogen.math_utils import get_answer
import copy
import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from openai import BadRequestError

import autogen
from autogen import config_list_from_json
from autogen.agentchat import Agent
from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer
from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent
from autogen.code_utils import extract_code
from autogen.math_utils import get_answer

以上是一段 Python 代码,导入了一些模块和类。这段代码使用了 openai 模块中的 BadRequestError 异常类,以及 autogen 模块中的一些函数和类。

其中,autogen 模块是一个自动生成代码的工具,它包含了一些用于处理配置文件、代理聊天和数学计算的函数和类。在这段代码中,还导入了 autogen.agentchat 模块中的 Agent 类,以及 autogen.agentchat.contrib 模块中的 AgentOptimizerMathUserProxyAgent 类。

此外,还导入了 autogen.code_utils 模块中的 extract_code 函数和 autogen.math_utils 模块中的 get_answer 函数。这些函数和类的具体功能可以根据需要进行调用和使用。

使用 function_call 的 MathUserProxy

这个代理是一个定制的 MathUserProxy,继承自它的父类

它支持使用 function_call 和 python 来解决数学问题。

def is_termination_msg_mathchat(message):
"""检查消息是否为终止消息。"""
if isinstance(message, dict):
message = message.get("content")
if message is None:
return False
cb = extract_code(message)
contain_code = False
for c in cb:
if c[0] == "python":
contain_code = True
break
if message.rstrip().find("TERMINATE") >= 0:
return True
return not contain_code and get_answer(message) is not None and get_answer(message) != ""


class MathUserProxyAgent(MathUserProxyAgent):
MAX_CONSECUTIVE_AUTO_REPLY = 15
DEFAULT_REPLY = "继续。请继续解决问题,直到需要查询为止。(如果你得到了答案,请将其放在 \\boxed{} 中。)"
PROMPTS = """让我们解决一个数学问题。
查询要求:
输出时应使用 'print' 函数,并使用分数/根式形式而不是小数。
您可以使用像 sympy 这样的包来帮助您。
您必须按照以下格式编写代码:
```python
# your code
```
如果缺少某些包,您还可以建议安装相应的包。

请按照以下步骤进行:
1. 逐步解决问题(不要过度细分步骤)。
2. 将可以通过 Python 代码询问的任何查询(例如,可以计算的任何计算或方程式)和您在本次对话上下文中了解的函数提取出来。

请注意:
(1) 不要在一个步骤中混合建议的 Python 代码和函数调用。
(2) 您必须记住,您没有一个名为 "python" 的函数可用。

您必须按照以下格式编写您的 Python 代码:
```python
# your code
```

3. Wait for me to give the results or wait for the executed results of the function call.
4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.

After all the queries are run and you get the answer, put the answer in \\boxed{}.

Problem:
"""

def __init__(
self,
name: Optional[str] = "MathChatAgent",
is_termination_msg: Optional[Callable[[Dict], bool]] = is_termination_msg_mathchat,
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,
max_invalid_q_per_step=3,
**kwargs,
):
super().__init__(
name=name,
is_termination_msg=is_termination_msg,
human_input_mode=human_input_mode,
default_auto_reply=default_auto_reply,
max_invalid_q_per_step=max_invalid_q_per_step,
**kwargs,
)
del self._reply_func_list[2]
self.register_reply([Agent, None], MathUserProxyAgent._generate_math_reply, position=4)
del self._reply_func_list[3]
self.register_reply(
trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent.generate_function_call_reply, position=3
)
self.register_reply(
trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent._check_final_result, position=0
)

self.max_function_call_trial = 3
self.query = None
self.answer = None
self.is_correct = None

def generate_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[autogen.ConversableAgent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[Dict, None]]:
"""Generate a reply using function call."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
if "function_call" in message:
is_exec_success, func_return = self.execute_function(message["function_call"])
if is_exec_success:
self.max_function_call_trial = 3
return True, func_return
else:
if self.max_function_call_trial == 0:
error_message = func_return["content"]
self.max_function_call_trial = 3
return (
True,
"The func is executed failed many times. "
+ error_message
+ ". Please directly reply me with TERMINATE. We need to terminate the conversation.",
)
else:
revise_prompt = "You may make a wrong function call (It may due the arguments you provided doesn't fit the function arguments like missing required positional argument). \
If you think this error occurs due to you make a wrong function arguments input and you could make it success, please try to call this function again using the correct arguments. \
Otherwise, the error may be caused by the function itself. Please directly reply me with TERMINATE. We need to terminate the conversation. "
error_message = func_return["content"]
return True, "The func is executed failed." + error_message + revise_prompt
return False, None

def initiate_chat(
self,
recipient,
answer: None,
silent: Optional[bool] = False,
**context,
):
self.query = context["problem"]
if not isinstance(answer, str):
answer = str(answer)
if answer.endswith(".0"):
answer = answer[:-2]
self._answer = answer
else:
self._answer = answer

self.is_correct = None

self._prepare_chat(recipient, True)
error_message = None
try:
prompt = self.PROMPTS + context["problem"]
self.send(prompt, recipient, silent=silent)
except BadRequestError as e:
error_message = str(e)
self.is_correct = 0
print("error information: {}".format(error_message))

recipient.reset()
is_correct = copy.deepcopy(self.is_correct)
self._reset()
return is_correct

def _check_final_result(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[autogen.Agent] = None,
config: Optional[Any] = None,
):

messages = messages[-1]
if isinstance(messages, dict):
messages = messages.get("content")
if messages is None:
return False, None

cb = extract_code(messages)
contain_code = False
for c in cb:
if c[0] == "python":
contain_code = True
break
if not contain_code and get_answer(messages) is not None and get_answer(messages) != "":
if get_answer(messages) == self._answer:
self.is_correct = 1
return True, "The result is Correct. Please reply me with TERMINATE."
else:
self.is_correct = 0
return False, None
else:
return False, None

def _reset(self):
super()._reset()
self.max_function_call_trial = 3
self.is_correct = None
self.query = None
self.answer = None

加载数据集

MATAH数据集包含12,500个具有挑战性的竞赛数学问题。MATH中的每个问题都有一个完整的逐步解决方案,可以用来教导模型生成答案的推导和解释。

我们严格遵循CRAFTtrain/test划分。请指定您自己的数据集路径。这里我们以代数问题的前10个作为示例。

test_data, train_data = [], []
with open("MATH/dataset/algebra.jsonl", "r", encoding="utf-8") as f:
for line in f:
test_data.append(json.loads(line))
with open("MATH/dataset/train/algebra.jsonl", "r", encoding="utf-8") as f:
for line in f:
train_data.append(json.loads(line))
test_data, train_data = test_data[0:10], train_data[0:10]

构建代理

构建用于解决这些问题的MathUserProxyAgent和AssistantAgent。在这里,我们使用gpt-4-1106-preview来构建AssistantAgent。

llm_config = {
"config_list": [
{
"model": "gpt-4-1106-preview",
"api_type": "azure",
"api_key": os.environ["AZURE_OPENAI_API_KEY"],
"base_url": "https://ENDPOINT.openai.azure.com/",
"api_version": "2023-07-01-preview",
}
]
}

assistant = autogen.AssistantAgent(
name="assistant",
system_message="You are a helpful assistant.",
llm_config=llm_config,
)
user_proxy = MathUserProxyAgent(
name="mathproxyagent",
human_input_mode="NEVER",
code_execution_config={"work_dir": "_output", "use_docker": False},
)

无代理优化的测试

以下是在没有代理优化过程的情况下获取性能的代码。

在这种情况下,AssistantAgent和MathUserProxyAgent没有任何函数调用,只是使用Python解决问题。

sum = 0
for index, query in enumerate(test_data):
is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query["answer"], problem=query["question"])
print(is_correct)
sum += is_correct
success_rate_without_agent_training = sum / 10

代理训练

然后,我们使用AgentOptimizer来迭代优化代理,根据历史对话和性能优化函数调用。AgentOptimizer在每次迭代中产生register_for_llm和register_for_executor,然后分别用于更新assistant和user_proxy代理。在这里,我们为这两个代理进行了十个时期的优化。

EPOCH = 10
optimizer_model = "gpt-4-1106-preview"
optimizer = AgentOptimizer(max_actions_per_step=3, llm_config=llm_config, optimizer_model=optimizer_model)
for i in range(EPOCH):
for index, query in enumerate(train_data):
is_correct = user_proxy.initiate_chat(assistant, answer=query["answer"], problem=query["question"])
history = assistant.chat_messages_for_summary(user_proxy)
optimizer.record_one_conversation(history, is_satisfied=is_correct)
register_for_llm, register_for_exector = optimizer.step()
for item in register_for_llm:
assistant.update_function_signature(**item)
if len(register_for_exector.keys()) > 0:
user_proxy.register_function(function_map=register_for_exector)

在这段代码中,我们使用了一个名为 AgentOptimizer 的优化器来训练我们的助手模型。我们设置了 EPOCH 的值为 10,并选择了一个名为 "gpt-4-1106-preview" 的优化器模型。然后,我们使用一个循环来迭代训练数据集中的每个问题。对于每个问题,我们使用 user_proxy.initiate_chat 函数来与助手进行对话,并记录对话历史。然后,我们使用 optimizer.record_one_conversation 函数来记录这次对话的信息,包括是否满意。接下来,我们使用 optimizer.step 函数来执行一步优化,并获取需要更新的函数签名信息。最后,我们使用 assistant.update_function_signature 函数来更新助手的函数签名,并使用 user_proxy.register_function 函数来注册需要更新的函数。

经过代理优化的测试

在代理优化之后,代理通过10次优化迭代获得了一系列函数,如下所示。

然后我们展示了经过代理优化和未经代理优化的最终表现。我们观察到经过优化的代理明显更好。

sum = 0
for index, query in enumerate(test_data):
is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query["answer"], problem=query["question"])
sum += is_correct
success_rate_with_agent_training = sum / 10
print(
"------------------------------------------------学到的函数------------------------------------------------"
)
for func in assistant.llm_config["functions"]:
print(func["name"] + ": " + func["description"] + "\n")
print("------------------------------------------------总结------------------------------------------------\n")
print("未经代理训练的成功率: {average}%\n".format(average=success_rate_without_agent_training * 100))
print("经过代理训练的成功率: {average}%\n".format(average=success_rate_with_agent_training * 100))
------------------------------------------------学到的函数------------------------------------------------
evaluate_expression: 对以字符串形式提供的算术或数学表达式进行求值。

calculate_compound_interest_principal: 计算以季度复利的方式实现特定未来价值所需的本金。

solve_linear_system: 解决以系数和变量表示的线性方程组。

------------------------------------------------总结------------------------------------------------

未经代理训练的成功率: 60.0%

经过代理训练的成功率: 90.0%