"""使用分层规划方法与OpenAPI API进行交互的代理程序。"""
import json
import re
from functools import partial
from typing import Any, Callable, Dict, List, Optional, cast
import yaml
from langchain_core.callbacks import BaseCallbackManager
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool, Tool
from langchain_community.agent_toolkits.openapi.planner_prompt import (
API_CONTROLLER_PROMPT,
API_CONTROLLER_TOOL_DESCRIPTION,
API_CONTROLLER_TOOL_NAME,
API_ORCHESTRATOR_PROMPT,
API_PLANNER_PROMPT,
API_PLANNER_TOOL_DESCRIPTION,
API_PLANNER_TOOL_NAME,
PARSING_DELETE_PROMPT,
PARSING_GET_PROMPT,
PARSING_PATCH_PROMPT,
PARSING_POST_PROMPT,
PARSING_PUT_PROMPT,
REQUESTS_DELETE_TOOL_DESCRIPTION,
REQUESTS_GET_TOOL_DESCRIPTION,
REQUESTS_PATCH_TOOL_DESCRIPTION,
REQUESTS_POST_TOOL_DESCRIPTION,
REQUESTS_PUT_TOOL_DESCRIPTION,
)
from langchain_community.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain_community.llms import OpenAI
from langchain_community.tools.requests.tool import BaseRequestsTool
from langchain_community.utilities.requests import RequestsWrapper
# ---------------------------------------------------------------------------
# Request tools whose (truncated) responses are post-processed by an LLM
# following caller-supplied extraction instructions.
#
# Cutting the raw response off at a fixed length is crude and can throw away
# useful information, but it keeps extraction to a single inference step.
# ---------------------------------------------------------------------------

# Upper bound on the number of characters of a raw response that is handed to
# the extraction chain.
MAX_RESPONSE_LENGTH = 5_000
def _get_default_llm_chain(prompt: BasePromptTemplate) -> Any:
    """Build an ``LLMChain`` that runs *prompt* against a default ``OpenAI()`` LLM.

    ``LLMChain`` is imported inside the function, so the import is deferred
    until a default chain is actually requested.
    """
    from langchain.chains.llm import LLMChain

    default_llm = OpenAI()
    return LLMChain(prompt=prompt, llm=default_llm)
def _get_default_llm_chain_factory(
    prompt: BasePromptTemplate,
) -> Callable[[], Any]:
    """Return a zero-argument callable that builds a default chain for *prompt*.

    Each call of the returned callable invokes ``_get_default_llm_chain`` anew,
    which makes it usable as a pydantic ``default_factory``.
    """
    factory: Callable[[], Any] = partial(_get_default_llm_chain, prompt=prompt)
    return factory
class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
    """Requests POST tool with LLM-instructed extraction of truncated responses.

    The tool input is parsed with ``parse_json_markdown`` and must provide the
    keys ``url``, ``data`` and ``output_instructions``. The request is sent via
    ``requests_wrapper.post(url, data)``. The response is cut to
    ``response_length`` characters and handed, together with the instructions,
    to ``llm_chain`` for extraction.
    """

    name: str = "requests_post"
    """Tool name."""
    description: str = REQUESTS_POST_TOOL_DESCRIPTION
    """Tool description."""
    response_length: int = MAX_RESPONSE_LENGTH
    """Maximum number of response characters passed to ``llm_chain``."""
    llm_chain: Any = Field(
        default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
    )
    """Chain used to extract the relevant part of the response."""

    def _run(self, text: str) -> str:
        """Send the POST request described by *text* and return the extraction.

        Raises:
            KeyError: if the parsed input lacks ``url``, ``data`` or
                ``output_instructions``.
            Any error raised by ``parse_json_markdown`` for malformed input
            (presumably ``json.JSONDecodeError``) propagates unchanged.
        """
        from langchain.output_parsers.json import parse_json_markdown

        # The former ``except json.JSONDecodeError as e: raise e`` was a no-op;
        # parse errors simply propagate.
        data = parse_json_markdown(text)
        raw_response = self.requests_wrapper.post(data["url"], data["data"])
        # ``cast(str, ...)`` never converted anything at runtime. If the wrapper
        # hands back parsed JSON (e.g. a dict), slicing it would raise
        # TypeError, so serialize non-str results before truncating.
        if isinstance(raw_response, str):
            response = raw_response
        else:
            response = json.dumps(raw_response, ensure_ascii=False)
        response = response[: self.response_length]
        return self.llm_chain.predict(
            response=response, instructions=data["output_instructions"]
        ).strip()

    async def _arun(self, text: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()
#
# Orchestrator, planner, controller.
#
def _create_api_planner_tool(
    api_spec: ReducedOpenAPISpec, llm: BaseLanguageModel
) -> Tool:
    """Wrap an LLM planning chain over *api_spec*'s endpoints as a ``Tool``.

    The prompt is ``API_PLANNER_PROMPT`` with its ``endpoints`` variable
    pre-filled by a bulleted list of ``"<name> <description>"`` entries, one
    per item of ``api_spec.endpoints``. ``query`` is the remaining input.
    """
    from langchain.chains.llm import LLMChain

    # Join bullets with explicit newlines. The previous ``"- ".join(...)`` only
    # put endpoints on separate lines when every description already ended in
    # a newline, and a ``None`` description was rendered as the text "None".
    endpoints_text = "\n".join(
        f"- {name} {(description or '').strip()}".rstrip()
        for name, description, _ in api_spec.endpoints
    )
    prompt = PromptTemplate(
        template=API_PLANNER_PROMPT,
        input_variables=["query"],
        partial_variables={"endpoints": endpoints_text},
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    return Tool(
        name=API_PLANNER_TOOL_NAME,
        description=API_PLANNER_TOOL_DESCRIPTION,
        func=chain.run,
    )
def _create_api_controller_agent(
    api_url: str,
    api_docs: str,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
    allow_dangerous_requests: bool,
) -> Any:
    """Build a verbose ``AgentExecutor`` around a GET/POST controller agent.

    The agent is a ``ZeroShotAgent`` whose prompt (``API_CONTROLLER_PROMPT``)
    is pre-filled with *api_url*, *api_docs*, and the names and descriptions
    of its two tools. Each tool gets its own ``LLMChain`` (built from
    ``PARSING_GET_PROMPT`` / ``PARSING_POST_PROMPT``) for response extraction.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain

    get_parser_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
    post_parser_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
    tools: List[BaseTool] = [
        RequestsGetToolWithParsing(  # type: ignore[call-arg]
            llm_chain=get_parser_chain,
            requests_wrapper=requests_wrapper,
            allow_dangerous_requests=allow_dangerous_requests,
        ),
        RequestsPostToolWithParsing(  # type: ignore[call-arg]
            llm_chain=post_parser_chain,
            requests_wrapper=requests_wrapper,
            allow_dangerous_requests=allow_dangerous_requests,
        ),
    ]

    tool_names = [t.name for t in tools]
    tool_descriptions = "\n".join(f"{t.name}: {t.description}" for t in tools)
    controller_prompt = PromptTemplate(
        template=API_CONTROLLER_PROMPT,
        input_variables=["input", "agent_scratchpad"],
        partial_variables={
            "api_url": api_url,
            "api_docs": api_docs,
            "tool_names": ", ".join(tool_names),
            "tool_descriptions": tool_descriptions,
        },
    )
    controller = ZeroShotAgent(
        llm_chain=LLMChain(llm=llm, prompt=controller_prompt),
        allowed_tools=tool_names,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=controller, tools=tools, verbose=True
    )
def _create_api_controller_tool(
    api_spec: ReducedOpenAPISpec,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
    allow_dangerous_requests: bool,
) -> Tool:
    """Expose the controller as a tool.

    The tool is called with the planner's plan. For every ``METHOD /route``
    that the plan mentions, it collects the matching endpoint docs from
    *api_spec*. It then builds a fresh controller agent that only sees those
    docs, to limit context, and runs that agent on the plan.

    The returned tool's function raises ``ValueError`` when the plan names an
    endpoint that matches no entry of ``api_spec.endpoints``.
    """
    base_url = api_spec.servers[0]["url"]  # TODO: do better.

    # Compile one pattern per spec endpoint, once, instead of once per plan
    # step. Path parameters such as ``{id}`` become ``.*``. All literal text is
    # escaped, so regex metacharacters in routes (".", "(", "+", ...) match
    # literally. The raw-string pattern also avoids the invalid "\{" escape.
    # ``Pattern.match`` keeps the original prefix-matching leniency.
    endpoint_patterns = [
        (
            re.compile(
                ".*".join(re.escape(part) for part in re.split(r"\{.*?\}", name))
            ),
            docs,
        )
        for name, _, docs in api_spec.endpoints
    ]

    def _create_and_run_api_controller_agent(plan_str: str) -> str:
        # The route group is mandatory. The old ``(/\S+)*`` also matched zero
        # repetitions, so prose like "GET the user" produced the route-less
        # name "GET ", which can never match and aborted with ValueError.
        pattern = r"\b(GET|POST|PATCH|DELETE)\s+(/\S+)"
        matches = re.findall(pattern, plan_str)
        # Drop query strings and de-duplicate (first-seen order) so each
        # endpoint's docs are added to the context only once.
        endpoint_names = list(
            dict.fromkeys(
                "{method} {route}".format(method=method, route=route.split("?")[0])
                for method, route in matches
            )
        )
        doc_sections: List[str] = []
        for endpoint_name in endpoint_names:
            matched_docs = [
                docs
                for regex_name, docs in endpoint_patterns
                if regex_name.match(endpoint_name)
            ]
            if not matched_docs:
                raise ValueError(f"{endpoint_name} endpoint does not exist.")
            doc_sections.extend(
                f"== Docs for {endpoint_name} == \n{yaml.dump(docs)}\n"
                for docs in matched_docs
            )
        agent = _create_api_controller_agent(
            base_url,
            "".join(doc_sections),
            requests_wrapper,
            llm,
            allow_dangerous_requests,
        )
        return agent.run(plan_str)

    return Tool(
        name=API_CONTROLLER_TOOL_NAME,
        func=_create_and_run_api_controller_agent,
        description=API_CONTROLLER_TOOL_DESCRIPTION,
    )
def create_openapi_agent(
    api_spec: ReducedOpenAPISpec,
    requests_wrapper: RequestsWrapper,
    llm: BaseLanguageModel,
    shared_memory: Optional[Any] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    verbose: bool = True,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    allow_dangerous_requests: bool = False,
    **kwargs: Any,
) -> Any:
    """Construct an OpenAPI planner-and-controller agent for the given spec.

    Credentials are injected through *requests_wrapper*.

    A top-level "orchestrator" agent calls both the planner tool and the
    controller tool. The alternative would be a top-level planner that calls
    the controller with its plan; the orchestrator design keeps the planner
    simple.

    ``allow_dangerous_requests`` must be set to ``True`` to use an agent with
    ``BaseRequestsTool``. Requests can be dangerous and lead to security
    vulnerabilities; for example, a user could ask the server to call an
    internal server. Route requests through a proxy server, and do not accept
    input from untrusted sources without an appropriate sandbox.

    See https://python.langchain.com/docs/security for more security
    information.

    Args:
        api_spec: Reduced spec that supplies endpoints, docs and server URL.
        requests_wrapper: Wrapper used by the controller's request tools.
        llm: Language model shared by the orchestrator, planner and
            controller.
        shared_memory: Passed as ``memory`` to the orchestrator's
            ``LLMChain``.
        callback_manager: Forwarded to ``AgentExecutor``.
        verbose: Forwarded to ``AgentExecutor``.
        agent_executor_kwargs: Extra keyword arguments for
            ``AgentExecutor.from_agent_and_tools``.
        allow_dangerous_requests: Forwarded to the controller's request tools.
        **kwargs: Extra keyword arguments for ``ZeroShotAgent``.

    Returns:
        The orchestrator ``AgentExecutor``.
    """
    from langchain.agents.agent import AgentExecutor
    from langchain.agents.mrkl.base import ZeroShotAgent
    from langchain.chains.llm import LLMChain

    planner_tool = _create_api_planner_tool(api_spec, llm)
    controller_tool = _create_api_controller_tool(
        api_spec, requests_wrapper, llm, allow_dangerous_requests
    )
    tools = [planner_tool, controller_tool]

    names = [t.name for t in tools]
    orchestrator_prompt = PromptTemplate(
        template=API_ORCHESTRATOR_PROMPT,
        input_variables=["input", "agent_scratchpad"],
        partial_variables={
            "tool_names": ", ".join(names),
            "tool_descriptions": "\n".join(
                f"{t.name}: {t.description}" for t in tools
            ),
        },
    )
    orchestrator_chain = LLMChain(
        llm=llm, prompt=orchestrator_prompt, memory=shared_memory
    )
    orchestrator = ZeroShotAgent(
        llm_chain=orchestrator_chain, allowed_tools=names, **kwargs
    )

    executor_extras: Dict[str, Any] = agent_executor_kwargs or {}
    return AgentExecutor.from_agent_and_tools(
        agent=orchestrator,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **executor_extras,
    )