import json
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.tools.base import BaseTool
from langchain_core.callbacks.manager import Callbacks
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_experimental.pydantic_v1 import BaseModel
DEMONSTRATIONS = [
{
"role": "user",
"content": "please show me a video and an image of (based on the text) 'a boy is running' and dub it", # noqa: E501 是用来告诉 linter 忽略超过 501 个字符的行长度限制的注释。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制。 表示忽略 PEP 8 中的行长度限制。
},
{
"role": "assistant",
"content": '[{{"task": "video_generator", "id": 0, "dep": [-1], "args": {{"prompt": "a boy is running" }}}}, {{"task": "text_reader", "id": 1, "dep": [-1], "args": {{"text": "a boy is running" }}}}, {{"task": "image_generator", "id": 2, "dep": [-1], "args": {{"prompt": "a boy is running" }}}}]', # noqa: E501 是用来告诉 linter 忽略超过 501 个字符的行长度限制的注释。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制。 表示忽略 PEP 8 中的行长度限制。
},
{
"role": "user",
"content": "Give you some pictures e1.jpg, e2.png, e3.jpg, help me count the number of sheep?", # noqa: E501 是用来告诉 linter 忽略超过 501 个字符的行长度限制的注释。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制。 表示忽略 PEP 8 中的行长度限制。
},
{
"role": "assistant",
"content": '[ {{"task": "image_qa", "id": 0, "dep": [-1], "args": {{"image": "e1.jpg", "question": "How many sheep in the picture"}}}}, {{"task": "image_qa", "id": 1, "dep": [-1], "args": {{"image": "e2.jpg", "question": "How many sheep in the picture"}}}}, {{"task": "image_qa", "id": 2, "dep": [-1], "args": {{"image": "e3.jpg", "question": "How many sheep in the picture"}}}}]', # noqa: E501 是用来告诉 linter 忽略超过 501 个字符的行长度限制的注释。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制。 表示忽略 PEP 8 中的行长度限制。
},
]
[docs]class TaskPlaningChain(LLMChain):
"""用于执行任务的链。"""
[docs] @classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
demos: List[Dict] = DEMONSTRATIONS,
verbose: bool = True,
) -> LLMChain:
"""获取响应解析器。"""
system_template = """# 1 任务规划阶段:AI助手可以将用户输入解析为多个任务:[{{"task": task, "id": task_id, "dep": dependency_task_id, "args": {{"input name": text may contain <resource-dep_id>}}}}]。特殊标签"dep_id"指的是依赖任务中生成的文本/图像/音频资源(请考虑依赖任务是否生成了这种类型的资源),"dep_id"必须在"dep"列表中。"dep"字段表示前置任务的id,这些任务生成了当前任务依赖的新资源。任务必须从以下工具中选择(以及工具描述、输入名称和输出类型):{tools}。可能会有多个相同类型的任务。逐步考虑解决用户请求所需的所有任务。解析出尽可能少的任务,同时确保可以解决用户请求。注意任务之间的依赖关系和顺序。如果无法解析用户输入,则需要回复空的JSON []。""" # noqa: E501 是用来告诉 linter 忽略超过 501 个字符的行长度限制的注释。 表示忽略超过每行字符限制的警告。 表示忽略PEP 8中的行长度限制。 表示忽略 PEP 8 中的行长度限制。
human_template = """现在我输入:{input}。"""
system_message_prompt = SystemMessagePromptTemplate.from_template(
system_template
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
demo_messages: List[
Union[HumanMessagePromptTemplate, AIMessagePromptTemplate]
] = []
for demo in demos:
if demo["role"] == "user":
demo_messages.append(
HumanMessagePromptTemplate.from_template(demo["content"])
)
else:
demo_messages.append(
AIMessagePromptTemplate.from_template(demo["content"])
)
# 将消息添加到demo_messages列表中
prompt = ChatPromptTemplate.from_messages(
[system_message_prompt, *demo_messages, human_message_prompt]
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
[docs]class Step:
"""计划中的一步。"""
[docs] def __init__(
self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
):
self.task = task
self.id = id
self.dep = dep
self.args = args
self.tool = tool
[docs]class Plan:
"""一个执行的计划。"""
[docs] def __init__(self, steps: List[Step]):
self.steps = steps
def __str__(self) -> str:
return str([str(step) for step in self.steps])
def __repr__(self) -> str:
return str(self)
[docs]class BasePlanner(BaseModel):
"""规划器的基类。"""
[docs] @abstractmethod
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""根据输入,决定要做什么。"""
[docs] @abstractmethod
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""异步 给定输入,决定要做什么。"""
[docs]class PlanningOutputParser(BaseModel):
"""解析规划阶段的输出。"""
[docs] def parse(self, text: str, hf_tools: List[BaseTool]) -> Plan:
"""解析规划阶段的输出。
Args:
text: 规划阶段的输出。
hf_tools: 可用的工具。
Returns:
规划结果。
"""
steps = []
for v in json.loads(re.findall(r"\[.*\]", text)[0]):
choose_tool = None
for tool in hf_tools:
if tool.name == v["task"]:
choose_tool = tool
break
if choose_tool:
steps.append(Step(v["task"], v["id"], v["dep"], v["args"], tool))
return Plan(steps=steps)
[docs]class TaskPlanner(BasePlanner):
"""任务规划器。"""
llm_chain: LLMChain
output_parser: PlanningOutputParser
stop: Optional[List] = None
[docs] def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""根据输入,决定要做什么。"""
inputs["tools"] = [
f"{tool.name}: {tool.description}" for tool in inputs["hf_tools"]
]
llm_response = self.llm_chain.run(**inputs, stop=self.stop, callbacks=callbacks)
return self.output_parser.parse(llm_response, inputs["hf_tools"])
[docs] async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""异步 给定输入,决定要执行什么操作。"""
inputs["hf_tools"] = [
f"{tool.name}: {tool.description}" for tool in inputs["hf_tools"]
]
llm_response = await self.llm_chain.arun(
**inputs, stop=self.stop, callbacks=callbacks
)
return self.output_parser.parse(llm_response, inputs["hf_tools"])
[docs]def load_chat_planner(llm: BaseLanguageModel) -> TaskPlanner:
"""加载聊天计划。"""
llm_chain = TaskPlaningChain.from_llm(llm)
return TaskPlanner(llm_chain=llm_chain, output_parser=PlanningOutputParser())