Skip to main content

Custom-Models

· 8 min read

TL;DR

AutoGen现在支持自定义模型!这个功能使用户能够定义和加载自己的模型,从而实现更灵活和个性化的推理机制。通过遵循特定的协议,您可以将自定义模型集成到AutoGen中,并使用任何您想要的模型/API调用/硬编码响应来响应提示。

注意:根据您使用的模型,您可能需要调整Agent的默认提示

快速入门

一个交互式和简单的入门方法是按照这里的笔记本,将HuggingFace的本地模型加载到AutoGen中并用于推理,并对提供的类进行更改。

第一步:创建自定义模型客户端类

要开始在AutoGen中使用自定义模型,您需要创建一个模型客户端类,该类遵循client.py中定义的ModelClient协议。新的模型客户端类应实现以下方法:

  • create(): 返回一个实现了ModelClientResponseProtocol的响应对象(协议部分有更多细节)。
  • message_retrieval(): 处理响应对象并返回一个字符串列表或消息对象列表(协议部分有更多细节)。
  • cost(): 返回响应的成本。
  • get_usage(): 返回一个包含RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]键的字典。

以下是一个简单的自定义类的示例:

class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")

def create(self, params):
num_of_responses = params.get("n", 1)

# 可以使用SimpleNamespace创建自己的数据响应类
# 这里为了简单起见,使用SimpleNamespace
# 只要它遵循ModelClientResponseProtocol即可

response = SimpleNamespace()
response.choices = []
response.model = "model_name" # 应与OAI_CONFIG_LIST注册匹配

for _ in range(num_of_responses):
text = "this is a dummy text response"
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response

def message_retrieval(self, response):
choices = response.choices
return [choice.message.content for choice in choices]

def cost(self, response) -> float:
response.cost = 0
return 0

@staticmethod
def get_usage(response):
return {}

第二步:将配置添加到OAI_CONFIG_LIST中

必须设置 model_client_cls 字段为新类的名称(以字符串形式) "model_client_cls":"CustomModelClient"。其他字段将被转发到类的构造函数,因此您可以完全控制要指定的参数和如何使用它们。例如:

{
"model": "Open-Orca/Mistral-7B-OpenOrca",
"model_client_cls": "CustomModelClient",
"device": "cuda",
"n": 1,
"params": {
"max_length": 1000,
}
}

第三步:将新的自定义模型注册到将使用它的 Agent 中

如果在 Agent 的配置列表中添加了一个带有字段 "model_client_cls":"<class name>" 的配置,那么在创建 Agent 并在对话初始化之前,必须注册相应的模型以使用所需的类:

my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor])

model_client_cls=CustomModelClient 参数与 OAI_CONFIG_LIST 中指定的参数匹配,并且 CustomModelClient 是符合 ModelClient 协议的类(有关协议的更多详细信息,请参见下文)。

如果新的模型客户端在配置列表中,但在聊天初始化时未注册,则会引发错误。

协议详细信息

可以以多种方式创建自定义模型类,但需要遵守在 client.py 中定义的 ModelClient 协议和响应结构,如下所示。

响应协议目前使用了与 OpenAI 响应结构匹配的自动生成代码库中的最低要求字段。任何与 OpenAI 响应结构匹配的响应协议可能更加适应未来的更改,但我们从最低要求开始,以便更容易采用此功能。


class ModelClient(Protocol):
"""
客户端类必须实现以下方法:
- create 方法必须返回一个实现了 ModelClientResponseProtocol 的响应对象
- cost 方法必须返回响应的成本
- get_usage 方法必须返回一个包含以下键的字典:
- prompt_tokens
- completion_tokens
- total_tokens
- cost
- model

此类用于创建一个可以由 OpenAIWrapper 使用的客户端。
create 方法返回的响应必须遵循 ModelClientResponseProtocol,但可以根据需要进行扩展。
必须实现 message_retrieval 方法以返回响应中的字符串列表或消息列表。
"""

RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

class ModelClientResponseProtocol(Protocol):
class Choice(Protocol):
class Message(Protocol):
content: Optional[str]

message: Message

choices: List[Choice]
model: str

def create(self, params) -> ModelClientResponseProtocol:
...

def message_retrieval(
self, response: ModelClientResponseProtocol
) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
"""
检索并返回响应中的字符串列表或 Choice.Message 列表。

注意:如果返回 Choice.Message 列表,目前需要包含 OpenAI 的 ChatCompletion Message 对象的字段,
因为目前代码库中的函数或工具调用都期望这样,除非使用自定义代理。
"""
...

def cost(self, response: ModelClientResponseProtocol) -> float:
...

@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict:
"""使用 RESPONSE_USAGE_KEYS 返回响应的使用情况摘要。"""
...

故障排除步骤

如果某些功能无法正常工作,请按照以下检查清单进行排查:

  • 确保在创建自定义模型类时遵循了客户端协议和客户端响应协议
    • create() 方法:在 create 调用期间返回推理响应时,必须遵循 ModelClientResponseProtocol
    • message_retrieval() 方法:返回字符串列表或消息对象列表。如果返回消息对象列表,则当前必须包含 OpenAI 的 ChatCompletion 消息对象的字段,因为目前在代码库的其他函数或工具调用中都是这样期望的,除非使用了自定义代理。
    • cost() 方法:返回一个整数,如果不关心成本跟踪,可以返回 0
    • get_usage():返回一个字典,如果不关心使用情况跟踪,可以返回一个空字典 {}
  • 确保在 OAI_CONFIG_LIST 中有相应的条目,并且该条目具有 "model_client_cls":"<custom-model-class-name>" 字段。
  • 确保已使用相应的配置条目和新类注册了客户端 agent.register_model_client(model_client_cls=<class-of-custom-model>, [other optional args])
  • 确保在 OAI_CONFIG_LIST 中注册了所有定义的自定义模型。
  • 其他故障排除可能需要在自定义代码中进行。

结论

通过使用自定义模型的能力,AutoGen 现在为您的 AI 应用程序提供了更大的灵活性和功能。无论您是训练自己的模型还是想使用特定的预训练模型,AutoGen 都可以满足您的需求。祝您编码愉快!