mlflow.pytorch

mlflow.pytorch 模块提供了一个用于记录和加载 PyTorch 模型的 API。该模块以以下格式导出 PyTorch 模型:

PyTorch (原生) 格式

这是可以重新加载回 PyTorch 的主要格式。

mlflow.pyfunc

为基于通用 pyfunc 的部署工具和批量推理而生成。

mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None, log_models=True, log_datasets=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None, extra_tags=None, checkpoint=True, checkpoint_monitor='val_loss', checkpoint_mode='min', checkpoint_save_best_only=True, checkpoint_save_weights_only=False, checkpoint_save_freq='epoch')[源代码]

备注

Autologging 已知与以下包版本兼容:1.9.0 <= torch <= 2.4.1。当使用此范围之外的包版本时,Autologging 可能无法成功。

启用(或禁用)并配置从 PyTorch Lightning 到 MLflow 的自动日志记录。

当你调用 pytorch_lightning.Trainer()fit 方法时,会执行自动日志记录。

探索完整的 PyTorch MNIST ,以获取包含额外轻量化步骤实现的广泛示例。

注意:完全自动记录仅支持 PyTorch Lightning 模型,即那些继承自 pytorch_lightning.LightningModule 的模型。对于原生 PyTorch 的自动记录支持(即仅继承自 torch.nn.Module 的模型)仅自动记录对 torch.utils.tensorboard.SummaryWriteradd_scalaradd_hparams 方法的调用至 mlflow。在这种情况下,也不存在“epoch”的概念。

参数:
  • log_every_n_epoch – 如果指定,每 n 个周期记录一次指标。默认情况下,每个周期结束后记录指标。

  • log_every_n_step – 如果指定,每 n 个训练步骤记录一次批次指标。默认情况下,步骤的指标不会被记录。请注意,将其设置为 1 可能会导致性能问题,不推荐这样做。指标是根据 Lightning 的全局步骤号记录的,当使用多个优化器时,假设每个训练步骤中所有优化器都会被更新。

  • log_models – 如果 True,训练的模型会被记录为 MLflow 模型工件。如果 False,训练的模型不会被记录。

  • log_datasets – 如果 True,数据集信息将被记录到 MLflow 跟踪中。如果 False,数据集信息将不会被记录。

  • disable – 如果 True,禁用 PyTorch Lightning 的自动日志集成。如果 False,启用 PyTorch Lightning 的自动日志集成。

  • exclusive – 如果 True ,自动记录的内容不会记录到用户创建的 fluent 运行中。如果 False ,自动记录的内容会记录到活动的 fluent 运行中,这可能是用户创建的。

  • disable_for_unsupported_versions – 如果 True,则对未经过此版本 MLflow 客户端测试或不兼容的 pytorch 和 pytorch-lightning 版本禁用自动日志记录。

  • silent – 如果 True,在 PyTorch Lightning 自动日志记录期间,抑制 MLflow 的所有事件日志和警告。如果 False,在 PyTorch Lightning 自动日志记录期间显示所有事件和警告。

  • registered_model_name – 如果提供,每次训练模型时,它都会被注册为具有此名称的已注册模型的新的模型版本。如果该注册模型尚不存在,则会创建它。

  • extra_tags – 一个字典,包含要为 autologging 创建的每个托管运行设置的额外标签。

  • checkpoint – 启用自动模型检查点功能,此功能仅支持 pytorch-lightning >= 1.6.0。

  • checkpoint_monitor – 在自动模型检查点保存中,如果你将 model_checkpoint_save_best_only 设置为 True,则要监控的指标名称。

  • checkpoint_save_best_only – 如果为 True,自动模型检查点仅在模型被认为是根据监控量和之前的检查点模型被覆盖的“最佳”模型时保存。

  • checkpoint_mode – one of {“min”, “max”}。在自动模型检查点保存中,如果 save_best_only=True,则根据监视量的最大化或最小化来决定是否覆盖当前保存文件。

  • checkpoint_save_weights_only – 在自动模型检查点保存中,如果为 True,则仅保存模型的权重。否则,优化器状态、学习率调度器状态等也会添加到检查点中。

  • checkpoint_save_freq“epoch” 或整数。当使用 “epoch” 时,回调函数在每个 epoch 后保存模型。当使用整数时,回调函数在此数量的批次结束时保存模型。请注意,如果保存与 epoch 不对齐,监控的指标可能不太可靠(因为它可能只反映一个批次,因为指标在每个 epoch 都会重置)。默认为 “epoch”

示例
import os

import lightning as L
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

import mlflow.pytorch
from mlflow import MlflowClient


class MNISTModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.accuracy = Accuracy("multiclass", num_classes=10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        pred = logits.argmax(dim=1)
        acc = self.accuracy(pred, y)

        # PyTorch `self.log` will be automatically captured by MLflow.
        self.log("train_loss", loss, on_epoch=True)
        self.log("acc", acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


def print_auto_logged_info(r):
    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")]
    print(f"run_id: {r.info.run_id}")
    print(f"artifacts: {artifacts}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {tags}")


# Initialize our model.
mnist_model = MNISTModel()

# Load MNIST dataset.
train_ds = MNIST(
    os.getcwd(), train=True, download=True, transform=transforms.ToTensor()
)
# Only take a subset of the data for faster training.
indices = torch.arange(32)
train_ds = Subset(train_ds, indices)
train_loader = DataLoader(train_ds, batch_size=8)

# Initialize a trainer.
trainer = L.Trainer(max_epochs=3)

# Auto log all MLflow entities
mlflow.pytorch.autolog()

# Train the model.
with mlflow.start_run() as run:
    trainer.fit(mnist_model, train_loader)

# Fetch the auto logged parameters and metrics.
print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id))
mlflow.pytorch.get_default_conda_env()[源代码]
返回:

默认的 Conda 环境作为字典,用于由调用 save_model()log_model() 生成的 MLflow 模型。

示例
import mlflow

# Log PyTorch model
with mlflow.start_run() as run:
    mlflow.pytorch.log_model(model, "model", signature=signature)

# Fetch the associated conda environment
env = mlflow.pytorch.get_default_conda_env()
print(f"conda env: {env}")
输出
conda env {'name': 'mlflow-env',
           'channels': ['conda-forge'],
           'dependencies': ['python=3.8.15',
                            {'pip': ['torch==1.5.1',
                                     'mlflow',
                                     'cloudpickle==1.6.0']}]}
mlflow.pytorch.get_default_pip_requirements()[源代码]
返回:

此flavor生成的MLflow Models的默认pip需求列表。对 save_model()log_model() 的调用会生成一个pip环境,该环境至少包含这些需求。

mlflow.pytorch.load_checkpoint(model_class, run_id=None, epoch=None, global_step=None, kwargs=None)[源代码]

如果在 autologging 中启用 “checkpoint”,在 pytorch-lightning 模型训练执行期间,检查点模型会被记录为 MLflow 工件。使用此 API,您可以加载检查点模型。

如果你想加载最新的检查点,将 epochglobal_step 都设置为 None。如果在自动记录中将 checkpoint_save_freq 设置为 epoch,你可以将 epoch 参数设置为要加载的检查点的 epoch 来加载特定 epoch 的检查点。如果在自动记录中将 checkpoint_save_freq 设置为一个整数,你可以将 global_step 参数设置为要加载的检查点的全局步数来加载特定全局步数的检查点。epoch 参数和 global_step 不能同时设置。

参数:
  • model_class – 训练模型的类,该类应继承 ‘pytorch_lightning.LightningModule’。

  • run_id – 模型日志记录到的运行的ID。如果未提供,则使用当前活动的运行。

  • epoch – 要加载的检查点的时期,如果你将 “checkpoint_save_freq” 设置为 “epoch”。

  • global_step – 如果要加载的检查点的全局步骤,如果你将“checkpoint_save_freq”设置为一个整数。

  • kwargs – 初始化模型所需的任何额外关键字参数。

返回:

从指定检查点恢复的 pytorch-lightning 模型实例。

示例
import mlflow

mlflow.pytorch.autolog(checkpoint=True)

model = MyLightningModuleNet()  # A custom-pytorch lightning model
train_loader = create_train_dataset_loader()
trainer = Trainer()

with mlflow.start_run() as run:
    trainer.fit(model, train_loader)

run_id = run.info.run_id

# load latest checkpoint model
latest_checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id)

# load history checkpoint model logged in second epoch
checkpoint_model = mlflow.pytorch.load_checkpoint(MyLightningModuleNet, run_id, epoch=2)
mlflow.pytorch.load_model(model_uri, dst_path=None, **kwargs)[源代码]

从本地文件或运行中加载 PyTorch 模型。

参数:
  • model_uri – MLflow 模型的位置,采用 URI 格式,例如: - /Users/me/path/to/local/model - relative/path/to/local/model - s3://my_bucket/path/to/model - runs:/<mlflow_run_id>/run-relative/path/to/model - models:/<model_name>/<model_version> - models:/<model_name>/<stage> 有关支持的 URI 方案的更多信息,请参阅 引用工件

  • dst_path – 下载模型工件的本地文件系统路径。此目录必须已经存在。如果未指定,将创建一个本地输出路径。

  • kwargs – 传递给 torch.load 方法的关键字参数。

返回:

一个 PyTorch 模型。

示例
import torch
import mlflow.pytorch


model = nn.Linear(1, 1)

# Log the model
with mlflow.start_run() as run:
    mlflow.pytorch.log_model(model, "model")

# Inference after loading the logged model
model_uri = f"runs:/{run.info.run_id}/model"
loaded_model = mlflow.pytorch.load_model(model_uri)
for x in [4.0, 6.0, 30.0]:
    X = torch.Tensor([[x]])
    y_pred = loaded_model(X)
    print(f"predict X: {x}, y_pred: {y_pred.data.item():.2f}")
输出
predict X: 4.0, y_pred: 7.57
predict X: 6.0, y_pred: 11.64
predict X: 30.0, y_pred: 60.48
mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, registered_model_name=None, signature: ModelSignature = None, input_example: DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple = None, await_registration_for=300, requirements_file=None, extra_files=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs)[源代码]

将 PyTorch 模型记录为当前运行的 MLflow 工件。

警告

记录模型时附上签名以避免推理错误。如果模型在未附签名的情况下被记录,MLflow 模型服务器依赖于从 NumPy 推断的默认数据类型。然而,PyTorch 通常期望不同的默认值,特别是在解析浮点数时。您必须包含签名以确保模型以正确的数据类型记录,从而使 MLflow 模型服务器能够正确提供有效的输入。

参数:
  • pytorch_model

    要保存的 PyTorch 模型。可以是 eager 模型(torch.nn.Module 的子类)或通过 torch.jit.scripttorch.jit.trace 准备的脚本模型。

    该模型接受一个 torch.FloatTensor 作为输入,并生成一个输出张量。

    如果保存一个 eager 模型,模型的类及其所有代码依赖项,包括类定义本身,应包含在以下位置之一:

    • 模型 Conda 环境中列出的包,由 conda_env 参数指定。

    • code_paths 参数指定的文件中的一个或多个。

  • artifact_path – 运行相对的工件路径。

  • conda_env

    一个Conda环境的字典表示形式,或Conda环境yaml文件的路径。如果提供,这将描述模型应运行的环境。至少,它应指定包含在 get_default_conda_env() 中的依赖项。如果为 None,则通过 mlflow.models.infer_pip_requirements() 推断的pip要求添加一个conda环境到模型中。如果要求推断失败,则回退到使用 get_default_pip_requirements()。来自 conda_env 的pip要求被写入一个pip requirements.txt 文件,完整的conda环境被写入 conda.yaml。以下是一个conda环境的字典表示形式的*示例*:

    {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.8.15",
            {
                "pip": [
                    "torch==x.y.z"
                ],
            },
        ],
    }
    

  • code_paths – 本地文件系统路径列表,指向Python文件依赖项(或包含文件依赖项的目录)。这些文件在加载模型时会被*前置*到系统路径中。如果为给定模型声明了依赖关系,则应从公共根路径声明相对导入,以避免在加载模型时出现导入错误。有关``code_paths``功能的详细解释、推荐的使用模式和限制,请参阅`code_paths使用指南 <https://mlflow.org/docs/latest/model/dependencies.html?highlight=code_paths#saving-extra-code-with-an-mlflow-model>`_。

  • pickle_module – PyTorch 用于序列化(“pickle”)指定 pytorch_model 的模块。这作为 pickle_module 参数传递给 torch.save()。默认情况下,此模块也用于在加载时反序列化(“unpickle”)PyTorch 模型。

  • registered_model_name – 如果指定,在 registered_model_name 下创建一个模型版本,如果给定名称的注册模型不存在,则同时创建一个注册模型。

  • signature – 一个 ModelSignature 类的实例,描述了模型的输入和输出。如果没有指定但提供了 input_example,将根据提供的输入示例和模型自动推断签名。要在提供输入示例时禁用自动签名推断,请将 signature 设置为 False。要手动推断模型签名,请在具有有效模型输入(例如省略了目标列的训练数据集)和有效模型输出(例如在训练数据集上进行的模型预测)的数据集上调用 infer_signature(),例如:

  • input_example – 一个或多个有效的模型输入实例。输入示例用作提示,指示应向模型提供哪些数据。它将被转换为Pandas DataFrame,然后使用Pandas的面向分割的格式序列化为json,或者是一个numpy数组,其中示例将通过将其转换为列表来序列化为json。字节被base64编码。当``signature``参数为``None``时,输入示例用于推断模型签名。

  • await_registration_for – 等待模型版本完成创建并处于 READY 状态的秒数。默认情况下,函数等待五分钟。指定 0 或 None 以跳过等待。

  • requirements_file

    警告

    requirements_file 已被弃用。请改用 pip_requirements

    包含需求文件路径的字符串。远程URI会被解析为绝对文件系统路径。例如,考虑以下 requirements_file 字符串:

    requirements_file = "s3://my-bucket/path/to/my_file"
    

    在这种情况下,"my_file" 需求文件是从 S3 下载的。如果为 None,则不会向模型添加需求文件。

  • extra_files – 包含相应额外文件路径的列表,如果为 None ,则不会向模型添加额外文件。远程 URI 会被解析为绝对文件系统路径。例如,考虑以下 extra_files 列表: .. code-block:: python

  • pip_requirements – 可以是 pip 需求字符串的可迭代对象(例如 ["torch", "-r requirements.txt", "-c constraints.txt"]),或者是本地文件系统上 pip 需求文件的字符串路径(例如 "requirements.txt")。如果提供,这将描述该模型应运行的环境。如果为 None,则通过 mlflow.models.infer_pip_requirements() 从当前软件环境中推断默认的需求列表。如果需求推断失败,则回退到使用 get_default_pip_requirements()。需求和约束都会自动解析并分别写入 requirements.txtconstraints.txt 文件,并作为模型的一部分存储。需求也会写入模型 conda 环境(conda.yaml)文件的 pip 部分。

  • extra_pip_requirements – 可以是 pip 需求字符串的可迭代对象(例如 ["pandas", "-r requirements.txt", "-c constraints.txt"]),或者是本地文件系统上的 pip 需求文件的字符串路径(例如 "requirements.txt")。如果提供,这将描述附加的 pip 需求,这些需求会被追加到根据用户当前软件环境自动生成的一组默认 pip 需求中。需求和约束会分别自动解析并写入 requirements.txtconstraints.txt 文件,并作为模型的一部分存储。需求也会被写入模型的 conda 环境(conda.yaml)文件的 pip 部分。 .. 警告:: 以下参数不能同时指定: - conda_env - pip_requirements - extra_pip_requirements 这个示例 展示了如何使用 pip_requirementsextra_pip_requirements 指定 pip 需求。

  • metadata – 传递给模型并在 MLmodel 文件中存储的自定义元数据字典。

  • kwargs – 传递给 torch.save 方法的关键字参数。

返回:

一个包含记录模型元数据的 ModelInfo 实例。

示例
import numpy as np
import torch
import mlflow
from mlflow import MlflowClient
from mlflow.models import infer_signature

# Define model, loss, and optimizer
model = nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Create training data with relationship y = 2X
X = torch.arange(1.0, 26.0).reshape(-1, 1)
y = X * 2

# Training loop
epochs = 250
for epoch in range(epochs):
    # Forward pass: Compute predicted y by passing X to the model
    y_pred = model(X)

    # Compute the loss
    loss = criterion(y_pred, y)

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Create model signature
signature = infer_signature(X.numpy(), model(X).detach().numpy())

# Log the model
with mlflow.start_run() as run:
    mlflow.pytorch.log_model(model, "model")

    # convert to scripted model and log the model
    scripted_pytorch_model = torch.jit.script(model)
    mlflow.pytorch.log_model(scripted_pytorch_model, "scripted_model")

# Fetch the logged model artifacts
print(f"run_id: {run.info.run_id}")
for artifact_path in ["model/data", "scripted_model/data"]:
    artifacts = [
        f.path for f in MlflowClient().list_artifacts(run.info.run_id, artifact_path)
    ]
    print(f"artifacts: {artifacts}")
输出
run_id: 1a1ec9e413ce48e9abf9aec20efd6f71
artifacts: ['model/data/model.pth',
            'model/data/pickle_module_info.txt']
artifacts: ['scripted_model/data/model.pth',
            'scripted_model/data/pickle_module_info.txt']
../../_images/pytorch_logged_models.png

PyTorch 记录的模型

mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=None, code_paths=None, pickle_module=None, signature: ModelSignature = None, input_example: DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple = None, requirements_file=None, extra_files=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, **kwargs)[源代码]

将 PyTorch 模型保存到本地文件系统中的路径。

参数:
  • pytorch_model

    要保存的 PyTorch 模型。可以是 eager 模型(torch.nn.Module 的子类)或通过 torch.jit.scripttorch.jit.trace 准备的脚本模型。

    要保存一个eager模型,模型的类及其所有代码依赖项,包括类定义本身,应包含在以下位置之一:

    • 模型 Conda 环境中列出的包,由 conda_env 参数指定。

    • code_paths 参数指定的文件中的一个或多个。

  • path – 模型保存的本地路径。

  • conda_env

    一个Conda环境的字典表示形式,或Conda环境yaml文件的路径。如果提供,这将描述模型应运行的环境。至少,它应指定包含在 get_default_conda_env() 中的依赖项。如果为 None,则通过 mlflow.models.infer_pip_requirements() 推断的pip要求添加一个conda环境到模型中。如果要求推断失败,则回退到使用 get_default_pip_requirements()。来自 conda_env 的pip要求被写入一个pip requirements.txt 文件,完整的conda环境被写入 conda.yaml。以下是一个conda环境的字典表示形式的*示例*:

    {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.8.15",
            {
                "pip": [
                    "torch==x.y.z"
                ],
            },
        ],
    }
    

  • mlflow_modelmlflow.models.Model 正在添加此风格。

  • code_paths – 本地文件系统路径列表,指向Python文件依赖项(或包含文件依赖项的目录)。这些文件在加载模型时会被*前置*到系统路径中。如果为给定模型声明了依赖关系,则应从公共根路径声明相对导入,以避免在加载模型时出现导入错误。有关``code_paths``功能的详细解释、推荐的使用模式和限制,请参阅`code_paths使用指南 <https://mlflow.org/docs/latest/model/dependencies.html?highlight=code_paths#saving-extra-code-with-an-mlflow-model>`_。

  • pickle_module – PyTorch 用于序列化(“pickle”)指定 pytorch_model 的模块。这作为 pickle_module 参数传递给 torch.save()。默认情况下,此模块也用于在加载时反序列化(“unpickle”)模型。

  • signature – 一个 ModelSignature 类的实例,描述了模型的输入和输出。如果没有指定但提供了 input_example,将根据提供的输入示例和模型自动推断签名。要在提供输入示例时禁用自动签名推断,请将 signature 设置为 False。要手动推断模型签名,请在具有有效模型输入(例如省略了目标列的训练数据集)和有效模型输出(例如在训练数据集上进行的模型预测)的数据集上调用 infer_signature(),例如:

  • input_example – 一个或多个有效的模型输入实例。输入示例用作提示,指示应向模型提供哪些数据。它将被转换为Pandas DataFrame,然后使用Pandas的面向分割的格式序列化为json,或者是一个numpy数组,其中示例将通过将其转换为列表来序列化为json。字节被base64编码。当``signature``参数为``None``时,输入示例用于推断模型签名。

  • requirements_file

    警告

    requirements_file 已被弃用。请改用 pip_requirements

    包含需求文件路径的字符串。远程URI会被解析为绝对文件系统路径。例如,考虑以下 requirements_file 字符串:

    requirements_file = "s3://my-bucket/path/to/my_file"
    

    在这种情况下,"my_file" 需求文件是从 S3 下载的。如果为 None,则不会向模型添加需求文件。

  • extra_files – 包含相应额外文件路径的列表。远程URI会被解析为绝对文件系统路径。例如,考虑以下 extra_files 列表 - extra_files = [“s3://my-bucket/path/to/my_file1”, “s3://my-bucket/path/to/my_file2”] 在这种情况下,"my_file1 & my_file2" 额外文件会从S3下载。如果为 None ,则不会向模型添加额外文件。

  • pip_requirements – 可以是 pip 需求字符串的可迭代对象(例如 ["torch", "-r requirements.txt", "-c constraints.txt"]),或者是本地文件系统上 pip 需求文件的字符串路径(例如 "requirements.txt")。如果提供,这将描述该模型应运行的环境。如果为 None,则通过 mlflow.models.infer_pip_requirements() 从当前软件环境中推断默认的需求列表。如果需求推断失败,则回退到使用 get_default_pip_requirements()。需求和约束都会自动解析并分别写入 requirements.txtconstraints.txt 文件,并作为模型的一部分存储。需求也会写入模型 conda 环境(conda.yaml)文件的 pip 部分。

  • extra_pip_requirements – 可以是 pip 需求字符串的可迭代对象(例如 ["pandas", "-r requirements.txt", "-c constraints.txt"]),或者是本地文件系统上的 pip 需求文件的字符串路径(例如 "requirements.txt")。如果提供,这将描述附加的 pip 需求,这些需求会被追加到根据用户当前软件环境自动生成的一组默认 pip 需求中。需求和约束会分别自动解析并写入 requirements.txtconstraints.txt 文件,并作为模型的一部分存储。需求也会被写入模型的 conda 环境(conda.yaml)文件的 pip 部分。 .. 警告:: 以下参数不能同时指定: - conda_env - pip_requirements - extra_pip_requirements 这个示例 展示了如何使用 pip_requirementsextra_pip_requirements 指定 pip 需求。

  • metadata – 传递给模型并在 MLmodel 文件中存储的自定义元数据字典。

  • kwargs – 传递给 torch.save 方法的关键字参数。

示例
import os
import mlflow
import torch


model = nn.Linear(1, 1)

# Save PyTorch models to current working directory
with mlflow.start_run() as run:
    mlflow.pytorch.save_model(model, "model")

    # Convert to a scripted model and save it
    scripted_pytorch_model = torch.jit.script(model)
    mlflow.pytorch.save_model(scripted_pytorch_model, "scripted_model")

# Load each saved model for inference
for model_path in ["model", "scripted_model"]:
    model_uri = f"{os.getcwd()}/{model_path}"
    loaded_model = mlflow.pytorch.load_model(model_uri)
    print(f"Loaded {model_path}:")
    for x in [6.0, 8.0, 12.0, 30.0]:
        X = torch.Tensor([[x]])
        y_pred = loaded_model(X)
        print(f"predict X: {x}, y_pred: {y_pred.data.item():.2f}")
    print("--")
输出
Loaded model:
predict X: 6.0, y_pred: 11.90
predict X: 8.0, y_pred: 15.92
predict X: 12.0, y_pred: 23.96
predict X: 30.0, y_pred: 60.13
--
Loaded scripted_model:
predict X: 6.0, y_pred: 11.90
predict X: 8.0, y_pred: 15.92
predict X: 12.0, y_pred: 23.96
predict X: 30.0, y_pred: 60.13