跳至主内容

注册模型

在本教程中,为简化操作我们将使用本地跟踪服务器和模型注册表。但对于生产环境,我们建议使用远程跟踪服务器

步骤0:安装依赖项

pip install --upgrade mlflow

步骤1:注册模型

要使用MLflow模型注册表,您需要将您的MLflow模型添加到其中。这可以通过以下任一命令注册给定模型来完成:

  • mlflow.<model_flavor>.log_model(registered_model_name=<model_name>): 在将模型记录到跟踪服务器的同时同时注册该模型。
  • mlflow.register_model(<model_uri>, <model_name>): 在将模型记录到跟踪服务器之后注册该模型。请注意,您需要先记录模型才能运行此命令以获取模型URI。

MLflow支持多种模型风格。在下面的示例中,我们将使用scikit-learn的RandomForestRegressor来演示注册模型的最简单方法,但请注意您可以使用任何支持的模型风格。 在下面的代码片段中,我们启动一个mlflow运行并训练一个随机森林模型。然后记录一些相关的超参数、模型的均方误差(MSE),最后记录并注册模型本身。

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

import mlflow
import mlflow.sklearn

with mlflow.start_run() as run:
X, y = make_regression(n_features=4, n_informative=2, random_state=0, shuffle=False)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)

params = {"max_depth": 2, "random_state": 42}
model = RandomForestRegressor(**params)
model.fit(X_train, y_train)

# Log parameters and metrics using the MLflow APIs
mlflow.log_params(params)

y_pred = model.predict(X_test)
mlflow.log_metrics({"mse": mean_squared_error(y_test, y_pred)})

# Log the sklearn model and register as version 1
mlflow.sklearn.log_model(
sk_model=model,
artifact_path="sklearn-model",
input_example=X_train,
registered_model_name="sk-learn-random-forest-reg-model",
)
示例输出
Successfully registered model 'sk-learn-random-forest-reg-model'.
Created version '1' of model 'sk-learn-random-forest-reg-model'.

太好了!我们已经注册了一个模型。

在继续之前,让我们强调一些重要的实现注意事项。

  • 要注册模型,您可以使用mlflow.sklearn.log_model()中的registered_model_name参数, 或者在记录模型后调用mlflow.register_model()。通常我们建议使用前者,因为它更加简洁。
  • Model Signatures 为我们的模型输入和输出提供验证。log_model()中的input_example 会自动推断并记录签名。我们再次建议使用这个实现,因为 它很简洁。