Building a Custom Agent¶
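In this guide, we show you how to build a custom agent using LlamaIndex.
The easiest way to build a custom agent is to subclass CustomSimpleAgentWorker and implement a few required functions. You have complete flexibility in defining the agent's step-wise logic.
This lets you add arbitrarily complex reasoning logic on top of your RAG pipeline.
We show you how to build a simple agent that adds a retry layer on top of a RouterQueryEngine, so that it can keep retrying a query until the task is complete. We build it on top of both a SQL tool and a vector index query tool. Even if a tool errors out or answers only part of the question, the agent can keep retrying until the task is complete.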
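The retry idea can be sketched in plain Python, independent of LlamaIndex. This is a minimal sketch under stated assumptions: `answer` stands in for the router query engine, `evaluate` stands in for the LLM-based error check, and `retry_until_done`, `fake_answer` and `fake_evaluate` are hypothetical names used only for illustration.

```python
from typing import Callable, List, Tuple

# Hypothetical evaluator result: (has_error, suggested_new_question)
Evaluation = Tuple[bool, str]


def retry_until_done(
    question: str,
    answer: Callable[[str], str],
    evaluate: Callable[[str, str], Evaluation],
    max_iterations: int = 10,
) -> Tuple[str, List[Tuple[str, str]]]:
    """Re-ask a (possibly rewritten) question until the evaluator reports no error.

    Returns the last answer plus the (question, answer) history.
    """
    history: List[Tuple[str, str]] = []
    current, response = question, ""
    for _ in range(max_iterations):  # hard cap so a persistent error cannot loop forever
        response = answer(current)
        history.append((current, response))
        has_error, new_question = evaluate(current, response)
        if not has_error:
            break
        current = new_question  # retry with the evaluator's rewritten question
    return response, history


# Usage with stand-ins: the "tool" fails on a misspelled word, and the
# "evaluator" proposes a corrected question.
def fake_answer(q: str) -> str:
    return "error" if "transporation" in q else f"answer to: {q}"


def fake_evaluate(q: str, r: str) -> Evaluation:
    if r == "error":
        return True, q.replace("transporation", "transportation")
    return False, ""


final, history = retry_until_done("top modes of transporation?", fake_answer, fake_evaluate)
print(final)         # answer to: top modes of transportation?
print(len(history))  # 2
```

The history of (question, answer) pairs matters because the real evaluator is shown all previous pairs when deciding whether the latest answer is still erroneous.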
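Note: Any text-to-SQL application should be aware that executing arbitrary SQL queries can be a security risk. It is recommended to take precautions as needed, such as using restricted roles, read-only databases, sandboxing, etc.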
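One of the precautions above, a read-only database, can be illustrated with Python's standard `sqlite3` module. This is a sketch only: `PRAGMA query_only` is a SQLite-specific switch and is no substitute for database roles or sandboxing in production; `make_read_only` and `run_generated_sql` are hypothetical helper names.

```python
import sqlite3
from typing import List, Tuple, Union


def make_read_only(conn: sqlite3.Connection) -> sqlite3.Connection:
    # With query_only on, SQLite refuses any statement that would change the database.
    conn.execute("PRAGMA query_only = ON")
    return conn


def run_generated_sql(conn: sqlite3.Connection, sql: str) -> Union[List[Tuple], str]:
    """Run LLM-generated SQL; return rows, or an error string instead of raising."""
    try:
        return conn.execute(sql).fetchall()
    except sqlite3.Error as exc:
        return f"rejected: {exc}"


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE city_stats (city_name TEXT PRIMARY KEY, population INTEGER)")
conn.execute("INSERT INTO city_stats VALUES ('Tokyo', 13960000)")
conn.commit()
make_read_only(conn)

print(run_generated_sql(conn, "SELECT city_name FROM city_stats"))  # [('Tokyo',)]
print(run_generated_sql(conn, "DELETE FROM city_stats"))  # a "rejected: ..." message
```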
Set up the Custom Agent¶
Here we set up our custom agent.
Refresher¶
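In LlamaIndex, an agent consists of both an agent runner and an agent worker. The agent runner is an orchestrator that stores state such as memory, whereas the agent worker controls the step-wise execution of a task. Agent runners support both sequential and parallel execution. More details can be found in our lower-level API guide.
Most core agent logic (e.g. ReAct, function-calling loops) can be executed in the agent worker. We have therefore made it easy to subclass an agent worker and plug it into any agent runner.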
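The division of labour between runner and worker can be illustrated with two toy classes. This is a hypothetical sketch, not LlamaIndex's `AgentRunner` code: `ToyRunner` owns the conversation memory and drives the loop, while `ToyWorker` only knows how to execute one step of one task. All names here are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


class ToyWorker:
    """Executes one step at a time; keeps no state between tasks."""

    def initialize_step_state(self, task_input: str) -> Dict[str, Any]:
        return {"input": task_input, "steps": 0}

    def run_step(self, state: Dict[str, Any], memory: List[str]) -> Tuple[str, bool]:
        state["steps"] += 1
        reply = f"step {state['steps']} on {state['input']!r} (memory size {len(memory)})"
        return reply, state["steps"] >= 2  # pretend every task needs two steps


@dataclass
class ToyRunner:
    """Orchestrates tasks and stores state, such as memory, across them."""

    worker: ToyWorker
    memory: List[str] = field(default_factory=list)

    def chat(self, message: str) -> str:
        state = self.worker.initialize_step_state(message)
        reply, done = "", False
        while not done:
            reply, done = self.worker.run_step(state, self.memory)
        # Only the runner writes to memory, so it persists across chat() calls.
        self.memory.extend([message, reply])
        return reply


runner = ToyRunner(worker=ToyWorker())
runner.chat("first question")
print(runner.chat("second question"))  # step 2 on 'second question' (memory size 2)
```

Because the worker is stateless between tasks, the same worker could in principle be driven by a different runner, which is the property the paragraph above describes.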
Creating a Custom Agent Worker Subclass¶
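As mentioned above, we subclass CustomSimpleAgentWorker. This class already sets up some scaffolding for you: it can take in tools, callbacks, and an LLM, and it ensures that the state/steps are properly formatted. In the meantime, you mainly have to implement the following functions:
- _initialize_state
- _run_step
- _finalize_task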
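Some additional notes:
- You can optionally implement _arun_step as well if you want to support async chat in your agent.
- You can choose to override __init__ as long as you pass all remaining args and kwargs to super().
- CustomSimpleAgentWorker is implemented as a Pydantic BaseModel, which means you can also define your own custom properties.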
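The contract between the three required methods can be sketched with a stand-in base class. This is hypothetical: `ToySimpleWorker` only imitates the shape of `CustomSimpleAgentWorker` (a per-task state dict from `_initialize_state`, repeated `_run_step` calls until it reports done, then `_finalize_task`). The real methods receive a `Task` object and are driven by an agent runner, not by the made-up `run_task` method below.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToySimpleWorker:
    """Hypothetical stand-in that calls the three hooks in order."""

    def run_task(self, task_input: str) -> str:
        state = self._initialize_state(task_input)  # fresh state for every task
        output, is_done = "", False
        while not is_done:
            output, is_done = self._run_step(state, task_input)
        self._finalize_task(state)
        return output

    def _initialize_state(self, task_input: str) -> Dict[str, Any]:
        raise NotImplementedError

    def _run_step(self, state: Dict[str, Any], task_input: str) -> Tuple[str, bool]:
        raise NotImplementedError

    def _finalize_task(self, state: Dict[str, Any]) -> None:
        raise NotImplementedError


class CountdownWorker(ToySimpleWorker):
    """Example subclass: finishes after `max_steps` steps and logs each hook call."""

    def __init__(self, max_steps: int = 3, log: Optional[List[str]] = None) -> None:
        self.max_steps = max_steps
        self.log = log if log is not None else []

    def _initialize_state(self, task_input: str) -> Dict[str, Any]:
        return {"count": 0}

    def _run_step(self, state: Dict[str, Any], task_input: str) -> Tuple[str, bool]:
        state["count"] += 1
        self.log.append(f"step {state['count']}")
        is_done = state["count"] >= self.max_steps
        return f"{task_input} after {state['count']} steps", is_done

    def _finalize_task(self, state: Dict[str, Any]) -> None:
        self.log.append("finalized")


worker = CountdownWorker()
print(worker.run_task("demo"))  # demo after 3 steps
print(worker.log)  # ['step 1', 'step 2', 'step 3', 'finalized']
```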
Here is the full set of base properties on each CustomSimpleAgentWorker (that you need to or can pass in when constructing your custom agent):
tools: Sequence[BaseTool]
tool_retriever: Optional[ObjectRetriever[BaseTool]]
llm: LLM
callback_manager: CallbackManager
verbose: bool
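Note that tools and tool_retriever are mutually exclusive; you can pass in only one of them (e.g. either define a static list of tools, or define a callable that returns relevant tools given a user message). You can call get_tools(message: str) to return the relevant tools for a given message.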
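The difference between the two options, and what `get_tools` returns in each case, can be sketched as follows. This is a hypothetical stand-in: `ToolHolder` is not a LlamaIndex class, raising `ValueError` when both are passed is an assumed way of enforcing the exclusivity, tools are plain strings rather than `BaseTool` objects, and the keyword-matching retriever replaces a real `ObjectRetriever`.

```python
from typing import Callable, List, Optional, Sequence


class ToolHolder:
    def __init__(
        self,
        tools: Optional[Sequence[str]] = None,
        tool_retriever: Optional[Callable[[str], List[str]]] = None,
    ) -> None:
        if tools is not None and tool_retriever is not None:
            raise ValueError("Pass either `tools` or `tool_retriever`, not both.")
        self._tools = list(tools or [])
        self._tool_retriever = tool_retriever

    def get_tools(self, message: str) -> List[str]:
        if self._tool_retriever is None:
            return self._tools  # static list: same tools for every message
        return self._tool_retriever(message)  # retriever: tools depend on the message


ALL_TOOLS = ["sql_city_stats", "vector_toronto", "vector_berlin", "vector_tokyo"]


def keyword_retriever(message: str) -> List[str]:
    """Hypothetical retriever: keep vector tools whose city name appears in the message."""
    text = message.lower()
    return [t for t in ALL_TOOLS if t.startswith("vector_") and t.split("_", 1)[1] in text]


static = ToolHolder(tools=ALL_TOOLS)
dynamic = ToolHolder(tool_retriever=keyword_retriever)
print(static.get_tools("anything"))                  # all four tools
print(dynamic.get_tools("Sports teams in Tokyo?"))   # ['vector_tokyo']
```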
All of these properties are accessible via self when defining your custom agent.
%pip install llama-index-readers-wikipedia
%pip install llama-index-llms-openai
from llama_index.core.agent import (
CustomSimpleAgentWorker,
Task,
AgentChatResponse,
)
from typing import Dict, Any, List, Tuple, Optional
from llama_index.core.tools import BaseTool, QueryEngineTool
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.query_engine import RouterQueryEngine
from llama_index.core import ChatPromptTemplate, PromptTemplate
from llama_index.core.selectors import PydanticSingleSelector
from llama_index.core.bridge.pydantic import Field, BaseModel
Here we define some helper variables and methods, e.g. the prompt template used to detect errors, and the response format in Pydantic.
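from llama_index.core.llms import ChatMessage, MessageRole

DEFAULT_PROMPT_STR = """
Given previous question/response pairs, please determine if an error has occurred in the response, and suggest a modified question that will not trigger the error.

Examples of modified questions:
- The question itself is modified to elicit a non-erroneous response
- The question is augmented with context that will help the downstream system better answer the question.
- The question is augmented with examples of negative responses, or other negative questions.

An error means that either an exception has triggered, or the response is completely irrelevant to the question.

Please return the evaluation of the response in the following JSON format.

"""


def get_chat_prompt_template(
    system_prompt: str, current_reasoning: List[Tuple[str, str]]
) -> ChatPromptTemplate:
    # system prompt first, then the (role, content) history as chat messages
    system_msg = ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
    messages = [system_msg]
    for raw_msg in current_reasoning:
        if raw_msg[0] == "user":
            messages.append(
                ChatMessage(role=MessageRole.USER, content=raw_msg[1])
            )
        else:
            messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=raw_msg[1])
            )
    return ChatPromptTemplate(message_templates=messages)


class ResponseEval(BaseModel):
    """Evaluation of whether the response has an error."""

    has_error: bool = Field(
        ..., description="Whether the response has an error."
    )
    new_question: str = Field(..., description="The suggested new question.")
    explanation: str = Field(
        ...,
        description=(
            "The explanation for the error as well as for the new question. "
            "Can include the direct stack trace as well."
        ),
    )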
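`PydanticOutputParser` turns the LLM's JSON reply into a `ResponseEval` object, and `has_error` then decides whether the agent retries. The idea can be sketched with the standard library alone: a dataclass mirrors the three `ResponseEval` fields, and a hypothetical `parse_response_eval` helper pulls the outermost JSON object out of free-form LLM text. This is a simplified stand-in, not how the library's parser is implemented, and the sample reply is made up.

```python
import json
from dataclasses import dataclass


@dataclass
class ResponseEvalSketch:
    has_error: bool
    new_question: str
    explanation: str


def parse_response_eval(llm_text: str) -> ResponseEvalSketch:
    """Extract the outermost {...} span from the LLM output and load it as JSON."""
    start, end = llm_text.find("{"), llm_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in LLM output")
    data = json.loads(llm_text[start : end + 1])
    return ResponseEvalSketch(
        has_error=bool(data["has_error"]),
        new_question=data.get("new_question", ""),
        explanation=data.get("explanation", ""),
    )


# Hypothetical LLM reply with some chatter around the JSON payload.
reply = (
    "Here is my evaluation:\n"
    '{"has_error": true, "new_question": "What are the top modes of '
    'transportation for Tokyo, Japan?", "explanation": "Did not answer."}'
)
ev = parse_response_eval(reply)
print(ev.has_error)     # True
print(ev.new_question)  # What are the top modes of transportation for Tokyo, Japan?
```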
from llama_index.core.bridge.pydantic import PrivateAttr


class RetryAgentWorker(CustomSimpleAgentWorker):
    """Agent worker that adds a retry layer on top of a router.

    Continues iterating until there are no errors / the task is done.

    """

    prompt_str: str = Field(default=DEFAULT_PROMPT_STR)
    max_iterations: int = Field(default=10)

    _router_query_engine: RouterQueryEngine = PrivateAttr()

    def __init__(self, tools: List[BaseTool], **kwargs: Any) -> None:
        """Init params."""
        # validate that all tools are query engine tools
        for tool in tools:
            if not isinstance(tool, QueryEngineTool):
                raise ValueError(
                    f"Tool {tool.metadata.name} is not a query engine tool."
                )
        router_query_engine = RouterQueryEngine(
            selector=PydanticSingleSelector.from_defaults(),
            query_engine_tools=tools,
            verbose=kwargs.get("verbose", False),
        )
        super().__init__(
            tools=tools,
            **kwargs,
        )
        # set private attributes only after the Pydantic model has been initialized
        self._router_query_engine = router_query_engine

    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:
        """Initialize state."""
        return {"count": 0, "current_reasoning": []}

    def _run_step(
        self, state: Dict[str, Any], task: Task, input: Optional[str] = None
    ) -> Tuple[AgentChatResponse, bool]:
        """Run step.

        Returns:
            Tuple of (agent_response, is_done)

        """
        if "new_input" not in state:
            new_input = task.input
        else:
            new_input = state["new_input"]

        # first, run the router query engine
        response = self._router_query_engine.query(new_input)

        # append to the current reasoning
        state["current_reasoning"].extend(
            [("user", new_input), ("assistant", str(response))]
        )

        # then, check for errors
        # dynamically create a pydantic program for structured output extraction based on the template
        chat_prompt_tmpl = get_chat_prompt_template(
            self.prompt_str, state["current_reasoning"]
        )
        llm_program = LLMTextCompletionProgram.from_defaults(
            output_parser=PydanticOutputParser(output_cls=ResponseEval),
            prompt=chat_prompt_tmpl,
            llm=self.llm,
        )
        # run the program and look at the result
        response_eval = llm_program(
            query_str=new_input, response_str=str(response)
        )
        state["count"] += 1
        # stop when no error is reported, or once max_iterations is reached,
        # so that a persistent error cannot make the agent loop forever
        if not response_eval.has_error or state["count"] >= self.max_iterations:
            is_done = True
        else:
            is_done = False
            state["new_input"] = response_eval.new_question

        if self.verbose:
            print(f"> Question: {new_input}")
            print(f"> Response: {response}")
            print(f"> Response eval: {response_eval.dict()}")

        # return the response
        return AgentChatResponse(response=str(response)), is_done

    def _finalize_task(self, state: Dict[str, Any], **kwargs) -> None:
        """Finalize task."""
        # nothing to finalize here
        # this is usually used to modify any internal state beyond what is set in `_initialize_state`
        pass
Set up Data and Tools¶
We set up a SQL tool over a city table, as well as a vector index tool for each city.
from llama_index.core.tools import QueryEngineTool
Set up SQL DB + Tool¶
from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    select,
    column,
)
from llama_index.core import SQLDatabase

# in-memory SQLite engine
engine = create_engine("sqlite:///:memory:", future=True)
metadata_obj = MetaData()

# create the city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
# create all tables on the engine
metadata_obj.create_all(engine)
from sqlalchemy import insert
rows = [
{"city_name": "Toronto", "population": 2930000, "country": "Canada"},
{"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
{"city_name": "Berlin", "population": 3645000, "country": "Germany"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)
from llama_index.core.query_engine import NLSQLTableQueryEngine
sql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(
sql_database=sql_database, tables=["city_stats"], verbose=True
)
sql_tool = QueryEngineTool.from_defaults(
query_engine=sql_query_engine,
description=(
"Useful for translating a natural language query into a SQL query over"
" a table containing: city_stats, containing the population/country of"
" each city"
),
)
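To see what the SQL tool has to work with, the same table can be rebuilt with Python's standard `sqlite3` module (instead of SQLAlchemy) and described in roughly the format that appears as `Table desc str` in the logs further below. The `describe_table` helper is hypothetical and only approximates LlamaIndex's own formatting; the two queries are ones the SQL tool generated in the example runs below.

```python
import sqlite3

# Rebuild the city_stats table with the same schema and rows as above.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats ("
    "city_name VARCHAR(16) PRIMARY KEY, "
    "population INTEGER, "
    "country VARCHAR(16) NOT NULL)"
)
conn.executemany(
    "INSERT INTO city_stats VALUES (?, ?, ?)",
    [
        ("Toronto", 2930000, "Canada"),
        ("Tokyo", 13960000, "Japan"),
        ("Berlin", 3645000, "Germany"),
    ],
)


def describe_table(conn: sqlite3.Connection, table: str) -> str:
    """Hypothetical helper: list columns with their declared types."""
    # PRAGMA table_info returns (cid, name, type, notnull, dflt_value, pk) per column.
    # The table name is interpolated directly, so only pass trusted names.
    cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
    col_str = ", ".join(f"{name} ({col_type})" for _, name, col_type, *_ in cols)
    return f"Table '{table}' has columns: {col_str}"


print(describe_table(conn, "city_stats"))
# Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))

# Queries the SQL tool generated in the example runs below.
country_rows = conn.execute("SELECT city_name, country FROM city_stats").fetchall()
largest = conn.execute(
    "SELECT city_name, population, country FROM city_stats "
    "WHERE population = (SELECT MAX(population) FROM city_stats)"
).fetchall()
print(largest)  # [('Tokyo', 13960000, 'Japan')]
```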
Set up Vector Tools¶
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.core import VectorStoreIndex
cities = ["Toronto", "Berlin", "Tokyo"]
wiki_docs = WikipediaReader().load_data(pages=cities)
# build a separate vector index for each city
# you could also choose to define a single vector index across all docs, and annotate each chunk with metadata
vector_tools = []
for city, wiki_doc in zip(cities, wiki_docs):
    vector_index = VectorStoreIndex.from_documents([wiki_doc])
    vector_query_engine = vector_index.as_query_engine()
    vector_tool = QueryEngineTool.from_defaults(
        query_engine=vector_query_engine,
        description=f"Useful for answering semantic questions about {city}",
    )
    vector_tools.append(vector_tool)
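The comment above mentions an alternative: a single index over all documents, with each chunk annotated by city metadata and filtered at query time. A dependency-free toy sketch of that idea follows. Word-overlap scoring stands in for embedding similarity, and the chunk texts and the `retrieve` helper are made up for illustration rather than taken from the Wikipedia pages.

```python
from typing import Any, Dict, List, Optional, Set

# Each chunk carries metadata naming the city it came from (hypothetical sample text).
CHUNKS: List[Dict[str, Any]] = [
    {"text": "Toronto is served by the TTC subway and streetcars.", "metadata": {"city": "Toronto"}},
    {"text": "Berlin has the U-Bahn and S-Bahn rail networks.", "metadata": {"city": "Berlin"}},
    {"text": "Tokyo relies heavily on trains and subways.", "metadata": {"city": "Tokyo"}},
]


def _words(text: str) -> Set[str]:
    return {w.strip(".,?!").lower() for w in text.split()}


def retrieve(query: str, city: Optional[str] = None, top_k: int = 1) -> List[str]:
    """Filter chunks by metadata first, then rank the rest by word overlap with the query."""
    candidates = [c for c in CHUNKS if city is None or c["metadata"]["city"] == city]
    q = _words(query)
    ranked = sorted(candidates, key=lambda c: len(q & _words(c["text"])), reverse=True)
    return [c["text"] for c in ranked[:top_k]]


print(retrieve("What subways does the city have?", city="Tokyo"))
# ['Tokyo relies heavily on trains and subways.']
```

The per-city indexes used in this guide achieve the same separation structurally, which is what lets the router pick one city's tool by its description.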
Build the Custom Agent¶
from llama_index.llms.openai import OpenAI
llm = OpenAI(model="gpt-4")
callback_manager = llm.callback_manager
query_engine_tools = [sql_tool] + vector_tools
agent_worker = RetryAgentWorker.from_tools(
query_engine_tools,
llm=llm,
verbose=True,
callback_manager=callback_manager,
)
agent = agent_worker.as_agent(callback_manager=callback_manager)
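Inside `RetryAgentWorker`, the `RouterQueryEngine` picks one of `query_engine_tools` per question. In the logs below this appears as `Selecting query engine N`, where `N` is a 0-based index into `[sql_tool] + vector_tools` (0 = SQL, 1 = Toronto, 2 = Berlin, 3 = Tokyo), while the LLM's own reasoning counts choices from 1. The real `PydanticSingleSelector` asks the LLM to choose; the sketch below swaps in a naive word-overlap scorer purely to illustrate the selection step, and the shortened descriptions and stopword list are hypothetical.

```python
import re
from typing import List, Set

# Shortened, hypothetical versions of the four tool descriptions, in tool order.
DESCRIPTIONS: List[str] = [
    "Translate a natural language query into a SQL query over city_stats "
    "(population and country of each city)",
    "Semantic questions about Toronto",
    "Semantic questions about Berlin",
    "Semantic questions about Tokyo",
]

STOPWORDS = {"a", "an", "the", "of", "for", "and", "are", "is",
             "what", "which", "about", "into", "over", "each"}


def tokens(text: str) -> Set[str]:
    return {w for w in re.findall(r"[a-z_]+", text.lower()) if w not in STOPWORDS}


def select_index(question: str, descriptions: List[str]) -> int:
    """Return the 0-based index of the description sharing the most words with the question."""
    q = tokens(question)
    scores = [len(q & tokens(d)) for d in descriptions]
    return max(range(len(descriptions)), key=lambda i: scores[i])


print(select_index("Which countries are each city from?", DESCRIPTIONS))  # 0
print(select_index("What are the top modes of transportation for Tokyo, Japan?", DESCRIPTIONS))  # 3
```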
Try Out Some Queries¶
response = agent.chat("Which countries are each city from?")
print(str(response))
Selecting query engine 0: The choice is about translating a natural language query into a SQL query over a table containing city_stats, which likely includes information about the country of each city..
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .
> Predicted SQL query: SELECT city_name, country FROM city_stats
> Question: Which countries are each city from?
> Response: The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.
> Response eval: {'has_error': True, 'new_question': 'Which country is each of the following cities from: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague as it did not specify which cities the question was referring to. The new question provides specific cities for which the country of origin is being asked.'}
Selecting query engine 0: This choice is relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .
> Predicted SQL query: SELECT city_name, country FROM city_stats WHERE city_name IN ('Toronto', 'Tokyo', 'Berlin')
> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?
> Response: Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.
> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}
Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.
response = agent.chat(
"What are the top modes of transporation fo the city with the higehest population?"
)
print(str(response))
Selecting query engine 0: The question is asking about the top modes of transportation for the city with the highest population. Choice (1) is the most relevant because it mentions a table containing city_stats, which likely includes information about the population of each city..
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .
> Predicted SQL query: SELECT city_name, population, mode_of_transportation FROM city_stats WHERE population = (SELECT MAX(population) FROM city_stats) ORDER BY mode_of_transportation ASC LIMIT 5;
> Question: What are the top modes of transporation fo the city with the higehest population?
> Response: I'm sorry, but there was an error in retrieving the information. Please try again later.
> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for the city with the highest population?', 'explanation': 'The original question had spelling errors which might have caused the system to not understand the question correctly. The corrected question should now be clear and understandable for the system.'}
Selecting query engine 0: The first choice is the most relevant because it mentions translating a natural language query into a SQL query over a table containing city_stats, which likely includes information about the population of each city..
> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .
> Predicted SQL query: SELECT city_name, population, country FROM city_stats WHERE population = (SELECT MAX(population) FROM city_stats)
> Question: What are the top modes of transportation for the city with the highest population?
> Response: The city with the highest population is Tokyo, Japan with a population of 13,960,000.
> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The assistant failed to answer the original question correctly. The response was about the city with the highest population, but it did not mention anything about the top modes of transportation in that city. The new question directly asks about the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}
Selecting query engine 3: The question specifically asks about Tokyo, and choice (4) is about answering semantic questions about Tokyo..
> Question: What are the top modes of transportation for Tokyo, Japan?
> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in the public transportation system. Additionally, expressways connect Tokyo to other points in the Greater Tokyo Area and beyond. Taxis and long-distance ferries are also available for transportation within the city and to the surrounding islands.
> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The original question was not answered correctly because the assistant did not provide information on the top modes of transportation for the city with the highest population. The new question directly asks for the top modes of transportation for Tokyo, Japan, which is the city with the highest population.'}
Selecting query engine 3: Tokyo is mentioned in choice 4.
> Question: What are the top modes of transportation for Tokyo, Japan?
> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.
> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The response is erroneous because it does not answer the question asked. The question asks for the top modes of transportation in the city with the highest population, but the response only provides the population of the city. The new question directly asks for the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}
Selecting query engine 3: The question specifically asks about Tokyo, and choice 4 is about answering semantic questions about Tokyo..
> Question: What are the top modes of transportation for Tokyo, Japan?
> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.
> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}
The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.
print(str(response))
The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.
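The first SQL attempt in this run failed because the generated query referenced a `mode_of_transportation` column that `city_stats` does not have; the agent then retried with rewritten questions. That failure can be reproduced with the standard `sqlite3` module against a minimal copy of the table. This is an illustration only: the NL-to-SQL query engine reports such errors in its own way, as the "error in retrieving the information" response in the log shows.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE city_stats (city_name VARCHAR(16) PRIMARY KEY, "
    "population INTEGER, country VARCHAR(16) NOT NULL)"
)
conn.execute("INSERT INTO city_stats VALUES ('Tokyo', 13960000, 'Japan')")

# The query the SQL tool predicted on its first attempt (copied from the log above).
bad_sql = (
    "SELECT city_name, population, mode_of_transportation FROM city_stats "
    "WHERE population = (SELECT MAX(population) FROM city_stats) "
    "ORDER BY mode_of_transportation ASC LIMIT 5"
)

try:
    conn.execute(bad_sql)
    error = None
except sqlite3.OperationalError as exc:
    error = str(exc)

print(error)  # no such column: mode_of_transportation
```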
response = agent.chat("What are the sports teams of each city in Asia?")
print(str(response))
Selecting query engine 3: The question is asking about sports teams in Asia, and Tokyo is located in Asia..
> Question: What are the sports teams of each city in Asia?
> Response: I'm sorry, but the context information does not provide a comprehensive list of sports teams in each city in Asia. It only mentions some sports teams in Tokyo, Japan. To get a complete list of sports teams in each city in Asia, you would need to consult a reliable source or conduct further research.
> Response eval: {'has_error': True, 'new_question': 'What are some popular sports teams in Tokyo, Japan?', 'explanation': 'The original question is too broad and requires extensive data that the system may not possess. The new question is more specific and focuses on a single city, making it more likely to receive a correct and comprehensive answer.'}
Selecting query engine 3: The question specifically asks about Tokyo, and choice 4 is about answering semantic questions about Tokyo..
> Question: What are some popular sports teams in Tokyo, Japan?
> Response: Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the Ryōgoku Kokugikan sumo arena.
> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}
Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the Ryōgoku Kokugikan sumo arena.