教程:使用 ChatModel 定制 GenAI 模型

快速发展的生成式人工智能(GenAI)领域既带来了令人兴奋的机遇,也带来了整合挑战。为了有效利用最新的GenAI进展,开发者需要一个在灵活性和标准化之间取得平衡的框架。MLflow通过在 版本2.11.0 中引入的 mlflow.pyfunc.ChatModel 类满足了这一需求,为GenAI应用提供了一致的接口,同时简化了部署和测试。

在 ChatModel 和 PythonModel 之间选择

在 MLflow 中构建 GenAI 应用程序时,选择一个既能平衡易用性又能满足你所需定制化程度的正确模型抽象至关重要。MLflow 为此提供了两个主要类:mlflow.pyfunc.ChatModelmlflow.pyfunc.PythonModel。每个类都有其自身的优势和权衡,因此理解哪个类最适合你的用例至关重要。

何时使用 ChatModel

  • 简化接口: mlflow.pyfunc.ChatModel 提供了一个专为对话式AI应用设计的简化接口。它遵循与OpenAI等流行GenAI服务兼容的标准化输入输出格式,确保部署的一致性。

  • 标准化:该模型类型强制执行广泛采用的 OpenAI API 规范,通过减少手动处理复杂输入模式的需求,简化了模型部署和集成。

  • 快速开始:如果你的目标是快速开始并进行最小化设置,mlflow.pyfunc.ChatModel 是一个极佳的选择。它抽象了许多复杂性,让你能够专注于应用程序逻辑,而不是管理详细的模型签名。

  • 较少自定义:这种简单性的权衡是更严格的结构。mlflow.pyfunc.ChatModel 在你使用案例与标准化接口非常匹配时是理想的,但如果你需要偏离规定的输入-输出模式,它可能会限制你。

何时使用 PythonModel

  • 完全自定义mlflow.pyfunc.PythonModel 提供了对模型输入、输出和处理逻辑的完全控制。这使得它在构建高度定制化的应用程序或与不遵循标准化API的模型和服务集成时成为首选。

  • 复杂集成:如果你的应用程序需要复杂的数据处理、多步骤的数据转换,或与不符合标准模式的独特API集成,mlflow.pyfunc.PythonModel 提供了处理这些任务所需的灵活性。

  • 增加的复杂性:然而,随着灵活性的增加,复杂性也随之增加。使用 mlflow.pyfunc.PythonModel 需要你定义和管理模型的输入和输出签名,这在处理GenAI用例中常见的JSON结构时可能会更具挑战性。

关键考虑因素

  • ChatModel 优点: 简单性、标准化、更快的部署、更少的代码管理。

  • ChatModel 缺点: 灵活性有限,标准化的输入可能无法满足所有自定义需求。

  • PythonModel 优点: 高度可定制,可以处理任何输入/输出格式,适应复杂需求。

  • PythonModel 缺点: 需要更多设置,在定义自定义签名时可能更容易出错,需要仔细管理输入转换。

推荐:当你需要一个快速、标准化且可靠的解决方案来构建与流行GenAI接口相符的对话代理时,使用 mlflow.pyfunc.ChatModel。当你的项目需要灵活性并能够自定义模型的每一个行为时,选择 mlflow.pyfunc.PythonModel

本教程的目的

本教程将指导您使用 MLflow 的 mlflow.pyfunc.ChatModel 类创建自定义聊天代理的过程。

在本教程结束时,您将:

先决条件

  • 熟悉MLflow日志记录API和GenAI概念。

  • 使用 mlflow.pyfunc.ChatModel 需要安装 MLflow 版本 2.11.0 或更高。

  • MLflow 版本 2.14.0 或更高版本已安装,以便使用 MLflow 追踪

本教程使用 Databricks Foundation Model APIs 纯粹作为一个与外部服务接口的示例。您可以轻松地将提供者示例切换为使用任何托管的 LLM 托管服务(Amazon BedrockAzure AI StudioOpenAIAnthropic 等)。

核心概念

Tracing Customization for GenAI


MLflow 追踪 允许你监控和记录模型方法的执行,在调试和性能优化期间提供有价值的见解。

在我们的示例 BasicAgent 实现中,我们利用了两个独立的API来启动跟踪跨度:装饰器API和流畅API。

Decorator API

@mlflow.trace
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")

    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()

使用 @mlflow.trace 追踪装饰器是将追踪功能添加到函数和方法的最简单方式。默认情况下,由该装饰器生成的跨度将使用函数的名称作为跨度的名称。可以按照以下方式覆盖此命名,以及其他与跨度相关的参数:

@mlflow.trace(name="custom_span_name", attributes={"key": "value"}, span_type="func")
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")

    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()

小技巧

始终建议为生成的任何跨度设置一个人类可读的名称,特别是如果您正在检测私有或通用命名的函数或方法。MLflow Trace UI 默认会显示函数或方法的名称,如果您的函数和方法名称模糊,可能会导致混淆。

Fluent API

对于需要完全控制每个跨度数据日志记录的方面时,流畅API 上下文处理程序的实现用于启动跨度是非常有用的。

以下是我们应用程序中的一个示例,用于确保我们在通过 load_context 方法加载模型时捕获设置的参数。我们正在从实例属性 self.models_configself.models 中提取数据来设置 span 的属性。

with mlflow.start_span("Audit Agent") as root_span:
    root_span.set_inputs(messages)
    attributes = {**params.to_dict(), **self.models_config, **self.models}
    root_span.set_attributes(attributes)
    # More span manipulation...

Traces in the MLflow UI

在运行包含这些用于跟踪跨度生成和检测的组合使用模式的示例之后,

Agent 示例在 MLflow UI 中的追踪

示例中的关键类和方法

  • BasicAgent: 我们自定义的聊天代理类,扩展了 ChatModel

  • _get_system_message: 获取特定角色的系统消息配置。

  • _get_agent_response`: 向一个端点发送消息并获取响应。

  • _call_agent: 管理代理角色之间的对话流程。

  • _prepare_message_list`: 准备发送的消息列表。

  • load_context: 初始化模型上下文和配置。

  • predict`: 处理聊天模型的预测逻辑。

在这些方法中,load_contextpredict 方法覆盖了 ChatModel 基类的抽象实现。为了定义 ChatModel 的子类,你必须实现(至少)``predict`` 方法。load_context 方法仅在你实现(如下所述)自定义加载逻辑时使用,其中模型对象需要加载静态配置才能工作,或者需要执行额外的依赖逻辑以使对象实例化正常工作。

自定义 ChatModel 的示例

在下面的完整示例中,我们通过子类化 mlflow.pyfunc.ChatModel 来创建一个自定义聊天代理。这个名为 BasicAgent 的代理利用了几个重要的特性,这些特性有助于简化 GenAI 应用程序的开发、部署和跟踪。通过子类化 ChatModel,我们确保了处理对话代理的一致接口,同时也避免了与更通用的模型相关的常见陷阱。

下面的实现突出了以下关键方面:

  • 追踪:我们利用 MLflow 的追踪功能,通过装饰器和 fluent API 上下文处理程序两种方法来跟踪和记录关键操作。

    • 装饰器 API:用于轻松追踪 _get_agent_response_call_agent 等方法,以实现自动跨度创建。

    • Fluent API:提供了对跨度创建的细粒度控制,如在 predict 方法中所示,用于在代理交互期间审计关键输入和输出。

    • 提示: 我们确保在 MLflow 跟踪 UI 中以及通过客户端 API 获取记录的跟踪时,使用人类可读的跨度名称,以便于调试。

  • 自定义配置

    • 模型配置:通过在模型记录期间传递自定义配置(使用 model_config 参数),我们将模型行为与硬编码值解耦。这使得无需修改源代码即可快速测试不同的代理配置。

    • load_context 方法: 确保配置在运行时加载,使用必要的设置初始化代理,并防止由于缺少配置导致的运行时失败。

    • 提示: 我们避免在 load_context 中直接设置未定义的实例属性。相反,所有属性都在类构造函数中用默认值初始化,以确保我们的模型正确加载。

  • 对话管理

    • 我们使用 _get_system_message_get_agent_response_call_agent 等方法实现了一个多步骤的代理交互模式。这些方法管理多个代理之间的通信流程,例如“oracle”和“judge”角色,每个角色都配置了特定的指令和参数。

    • 静态输入/输出结构:通过遵循 ChatModel 所需的输入(List[ChatMessage])和输出(ChatResponse)格式,我们消除了与转换 JSON 或表格数据相关的复杂性,这在更通用的模型如 PythonModel 中很常见。

  • 常见陷阱避免

    • 通过输入示例进行模型验证:在模型记录期间,我们提供一个输入示例,允许 MLflow 验证输入接口并在早期捕捉结构问题,从而减少部署期间的调试时间。

import mlflow
from mlflow.types.llm import ChatResponse, ChatMessage, ChatParams, ChatChoice
from mlflow.pyfunc import ChatModel
from mlflow import deployments
from typing import List, Optional, Dict


class BasicAgent(ChatModel):
    def __init__(self):
        """Initialize the BasicAgent with placeholder values."""
        self.deploy_client = None
        self.models = {}
        self.models_config = {}
        self.conversation_history = []

    def load_context(self, context):
        """Initialize the connectors and model configurations."""
        self.deploy_client = deployments.get_deploy_client("databricks")
        self.models = context.model_config.get("models", {})
        self.models_config = context.model_config.get("configuration", {})

    def _get_system_message(self, role: str) -> Dict:
        """
        Get the system message configuration for the specified role.

        Args:
            role (str): The role of the agent (e.g., "oracle" or "judge").

        Returns:
            dict: The system message for the given role.
        """
        if role not in self.models:
            raise ValueError(f"Unknown role: {role}")

        instruction = self.models[role]["instruction"]
        return ChatMessage(role="system", content=instruction).to_dict()

    @mlflow.trace(name="Raw Agent Response")
    def _get_agent_response(
        self, message_list: List[Dict], endpoint: str, params: Optional[dict] = None
    ) -> Dict:
        """
        Call the agent endpoint to get a response.

        Args:
            message_list (List[Dict]): List of messages for the agent.
            endpoint (str): The agent's endpoint.
            params (Optional[dict]): Additional parameters for the call.

        Returns:
            dict: The response from the agent.
        """
        response = self.deploy_client.predict(
            endpoint=endpoint, inputs={"messages": message_list, **(params or {})}
        )
        return response["choices"][0]["message"]

    @mlflow.trace(name="Agent Call")
    def _call_agent(
        self, message: ChatMessage, role: str, params: Optional[dict] = None
    ) -> Dict:
        """
        Prepares and sends the request to a specific agent based on the role.

        Args:
            message (ChatMessage): The message to be processed.
            role (str): The role of the agent (e.g., "oracle" or "judge").
            params (Optional[dict]): Additional parameters for the call.

        Returns:
            dict: The response from the agent.
        """
        system_message = self._get_system_message(role)
        message_list = self._prepare_message_list(system_message, message)

        # Fetch agent response
        agent_config = self.models[role]
        response = self._get_agent_response(
            message_list, agent_config["endpoint"], params
        )

        # Update conversation history
        self.conversation_history.extend([message.to_dict(), response])
        return response

    @mlflow.trace(name="Assemble Conversation")
    def _prepare_message_list(
        self, system_message: Dict, user_message: ChatMessage
    ) -> List[Dict]:
        """
        Prepare the list of messages to send to the agent.

        Args:
            system_message (dict): The system message dictionary.
            user_message (ChatMessage): The user message.

        Returns:
            List[dict]: The complete list of messages to send.
        """
        user_prompt = {
            "role": "user",
            "content": self.models_config.get(
                "user_response_instruction", "Can you make the answer better?"
            ),
        }
        if self.conversation_history:
            return [system_message, *self.conversation_history, user_prompt]
        else:
            return [system_message, user_message.to_dict()]

    def predict(
        self, context, messages: List[ChatMessage], params: Optional[ChatParams] = None
    ) -> ChatResponse:
        """
        Predict method to handle agent conversation.

        Args:
            context: The MLflow context.
            messages (List[ChatMessage]): List of messages to process.
            params (Optional[ChatParams]): Additional parameters for the conversation.

        Returns:
            ChatResponse: The structured response object.
        """
        # Use the fluent API context handler to have added control over what is included in the span
        with mlflow.start_span(name="Audit Agent") as root_span:
            # Add the user input to the root span
            root_span.set_inputs(messages)

            # Add attributes to the root span
            attributes = {**params.to_dict(), **self.models_config, **self.models}
            root_span.set_attributes(attributes)

            # Initiate the conversation with the oracle
            oracle_params = self._get_model_params("oracle")
            oracle_response = self._call_agent(messages[0], "oracle", oracle_params)

            # Process the response with the judge
            judge_params = self._get_model_params("judge")
            judge_response = self._call_agent(
                ChatMessage(**oracle_response), "judge", judge_params
            )

            # Reset the conversation history and return the final response
            self.conversation_history = []

            output = ChatResponse(
                choices=[ChatChoice(index=0, message=ChatMessage(**judge_response))],
                usage={},
                model=judge_params.get("endpoint", "unknown"),
            )

            root_span.set_outputs(output)

        return output

    def _get_model_params(self, role: str) -> dict:
        """
        Retrieves model parameters for a given role.

        Args:
            role (str): The role of the agent (e.g., "oracle" or "judge").

        Returns:
            dict: A dictionary of parameters for the agent.
        """
        role_config = self.models.get(role, {})

        return {
            "temperature": role_config.get("temperature", 0.5),
            "max_tokens": role_config.get("max_tokens", 500),
        }

既然我们已经定义了模型,那么在记录模型之前,只需要完成一个步骤:我们需要定义模型的配置,以便对其进行初始化。这是通过定义我们的 model_config 配置来完成的。

设置我们的 model_config

在记录模型之前,我们需要定义控制模型代理行为的配置。这种将配置与模型核心逻辑分离的做法,使我们能够轻松测试和比较不同代理行为,而无需修改模型实现。通过使用灵活的配置系统,我们可以高效地试验不同的设置,从而更容易迭代和微调我们的模型。

为什么要解耦配置?

在生成式人工智能(GenAI)的背景下,代理行为可能会根据给定的指令集和参数(如``temperature``或``max_tokens``)而有很大差异。如果我们直接将这些配置硬编码到模型的逻辑中,每次新测试都需要更改模型的源代码,这将导致:

  • 低效性:每次测试都需要更改源代码,这减缓了实验过程。

  • 增加错误风险:不断修改源代码会增加引入错误或意外副作用的可能性。

  • 缺乏可重复性:如果没有明确的代码和配置分离,追踪和重现用于特定结果的确切配置变得困难。

通过通过 model_config 参数从外部设置这些值,我们使模型变得灵活,能够适应不同的测试场景。这种方法还能与 MLflow 的评估工具无缝集成,例如 mlflow.evaluate(),它允许你系统地比较不同配置下的模型输出。

定义模型配置

配置由两个主要部分组成:

  1. 模型:本节定义了特定于代理的配置,例如本例中的 judgeoracle 角色。每个代理具有:

    • 端点:指定用于此代理的模型类型或服务。

    • 指令:定义代理的角色和职责(例如,回答问题、评估回应)。

    • 温度和最大令牌数: 控制生成变异性 (temperature) 和响应的令牌限制。

  2. 通用配置:模型整体行为的附加设置,例如如何为后续代理构建用户响应。

备注

有两种选项可用于设置模型配置:直接在日志代码中(如下所示)或在本地位置编写 yaml 格式的配置文件,该文件路径可以在定义日志时的 model_config 参数中指定。要了解更多关于 model_config 参数的使用方式,请参阅模型配置使用指南

以下是我们如何为我们的代理设置配置:

model_config = {
    "models": {
        "judge": {
            "endpoint": "databricks-meta-llama-3-1-405b-instruct",
            "instruction": (
                "You are an evaluator of answers provided by others. Based on the context of both the question and the answer, "
                "provide a corrected answer if it is incorrect; otherwise, enhance the answer with additional context and explanation."
            ),
            "temperature": 0.5,
            "max_tokens": 2000,
        },
        "oracle": {
            "endpoint": "databricks-mixtral-8x7b-instruct",
            "instruction": (
                "You are a knowledgeable source of information that excels at providing detailed, but brief answers to questions. "
                "Provide an answer to the question based on the information provided."
            ),
            "temperature": 0.9,
            "max_tokens": 5000,
        },
    },
    "configuration": {
        "user_response_instruction": "Can you evaluate and enhance this answer with the provided contextual history?"
    },
}

外部配置的好处

  • 灵活性: 解耦的配置允许我们轻松切换或调整模型行为,而无需修改核心逻辑。例如,我们可以更改模型的指令或调整 temperature 以测试响应中不同程度的创造力。

  • 可扩展性: 随着更多的代理被添加到系统中或新的角色被引入,我们可以在不使模型代码混乱的情况下扩展此配置。这种分离使得代码库更整洁、更易于维护。

  • 可重复性和比较:通过保持配置外部化,我们可以使用 MLflow 记录每次运行中使用的具体设置。这使得结果的再现和不同实验的比较变得更加容易,确保了稳健的评估和裁决过程,以选择性能最佳的配置。

配置完成后,我们现在可以记录模型并使用这些设置运行实验。通过利用 MLflow 强大的跟踪和记录功能,我们将能够高效地管理实验,并从代理的响应中提取有价值的见解。

定义一个输入示例

在记录我们的模型之前,提供一个 input_example 来展示如何与模型交互是非常重要的。这个示例有几个关键用途:

  • 记录时的验证:包含一个 input_example 允许 MLflow 在记录过程中使用此示例执行 predict 方法。这有助于验证您的模型能否处理预期的输入格式,并尽早发现任何问题。

  • UI 表示input_example 在 MLflow UI 中显示在模型的工件下。这为用户提供了一个方便的参考,以便在交互部署的模型时理解预期的输入结构。

通过提供一个输入示例,您可以确保您的模型使用真实数据进行测试,从而增加在部署时其行为符合预期的信心。

小技巧

在使用 mlflow.pyfunc.ChatModel 定义您的 GenAI 应用程序时,如果没有提供输入示例,将使用默认的占位符输入示例。如果您在 MLflow UI 的工件查看器中注意到一个不熟悉或通用的输入示例,很可能是系统分配的默认占位符。为了避免这种情况,请确保在保存模型时指定一个自定义输入示例。

以下是我们将使用的输入示例:

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is a good recipe for baking scones that doesn't require a lot of skill?",
        }
    ]
}

这个例子展示了一个用户请求简易司康饼食谱的情况。它符合我们的 BasicAgent 模型预期的输入结构,该模型处理一系列消息,每条消息包含一个 角色内容

提供输入示例的好处:

  • 执行与验证: MLflow 在记录时会将这个 input_example 传递给模型的 predict 方法,以确保它能够无误地处理输入。任何输入处理的问题,如数据类型错误或字段缺失,都将在此阶段被捕获,从而节省您后续调试的时间。

  • 用户界面显示: input_example 将在 MLflow UI 的模型工件视图部分中可见。这有助于用户理解模型期望的输入数据格式,使其在模型部署后更容易与模型交互。

  • 部署信心: 通过预先使用示例输入验证模型,您可以获得额外的保证,即模型将在生产环境中正确运行,从而降低部署后出现意外行为的风险。

包含一个 input_example 是一个简单但强大的步骤,用于验证您的模型是否已准备好部署,并且在接收到用户输入时将按预期行为运行。

记录和加载我们的自定义代理

要使用 MLflow 记录和加载模型,请使用:

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        "model",
        python_model=BasicAgent(),
        model_config=model_config,
        input_example=input_example,
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)

response = loaded.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "What is the best material to make a baseball bat out of?",
            }
        ]
    }
)

结论

在本教程中,您已经探索了使用 MLflow 的 mlflow.pyfunc.ChatModel 类创建自定义 GenAI 聊天代理的过程。我们演示了如何实现一种灵活、可扩展且标准化的方法来管理 GenAI 应用程序的部署,使您能够利用 AI 的最新进展,即使是对于那些在 MLflow 中尚未原生支持的库和框架。

通过使用 ChatModel 而不是更通用的 PythonModel ,您可以避免许多与部署 GenAI 相关的常见陷阱,利用不可变签名接口的优势,这些接口在您部署的任何 GenAI 接口中都是一致的,从而通过提供一致的体验来简化所有解决方案的使用。

本教程的关键要点包括:

  • 追踪与监控:通过将追踪直接集成到模型中,您可以深入了解应用程序的内部工作原理,使调试和优化更加直接。装饰器和流畅API方法都提供了管理关键操作追踪的多功能方式。

  • 灵活的配置管理:将配置与模型代码解耦,确保您可以在不修改源代码的情况下快速测试和迭代。这种方法不仅简化了实验过程,还随着应用程序的发展提高了可重复性和可扩展性。

  • 标准化输入和输出结构:利用 ChatModel 的静态签名简化了部署和服务 GenAI 模型的复杂性。通过遵循既定标准,您减少了通常与集成和验证输入/输出格式相关的摩擦。

  • 避免常见陷阱:在整个实现过程中,我们强调了避免常见问题的最佳实践,例如妥善处理密钥、验证输入示例以及理解加载上下文的细微差别。遵循这些实践可以确保您的模型在生产环境中保持安全、健壮和可靠。

  • 验证和部署准备: 在部署之前验证您的模型的重要性再怎么强调也不为过。通过使用工具如 mlflow.models.validate_serving_input(),您可以尽早发现并解决潜在的部署问题,从而在生产部署过程中节省时间和精力。

随着生成式人工智能领域的不断发展,构建适应性强且标准化的模型将对于利用未来几个月和几年内将解锁的令人兴奋且强大的功能至关重要。本教程中介绍的方法为您提供了一个强大的框架,用于在MLflow中集成和管理生成式AI技术,使您能够轻松地开发、跟踪和部署复杂的AI解决方案。

我们鼓励您扩展和定制这个基础示例,以满足您的特定需求并探索进一步的增强。通过利用 MLflow 不断增长的能力,您可以继续完善您的 GenAI 模型,确保它们在任何应用中都能提供有影响力且可靠的结果。