如何保存和加载试验检查点#

试验检查点是 Tune 存储的三种数据类型之一。这些是由用户定义的,旨在快照您的训练进度!

试验级别的检查点通过 Tune Trainable API 保存:这是你定义自定义训练逻辑的地方,也是你定义要检查点的试验状态的地方。在本指南中,我们将展示如何为 Tune 的函数训练和类训练 API 保存和加载检查点,并引导你了解配置选项。

函数 API 检查点#

如果使用 Ray Tune 的函数 API,可以按以下方式保存和加载检查点。要创建检查点,请使用 from_directory() API。

import os
import tempfile

from ray import train, tune
from ray.train import Checkpoint


def train_func(config):
    start = 1
    my_model = MyModel()

    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
            start = checkpoint_dict["epoch"] + 1
            my_model.load_state_dict(checkpoint_dict["model_state"])

    for epoch in range(start, config["epochs"] + 1):
        # Model training here
        # ...

        metrics = {"metric": 1}
        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(
                {"epoch": epoch, "model_state": my_model.state_dict()},
                os.path.join(tempdir, "checkpoint.pt"),
            )
            train.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))


tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()

在上面的代码片段中:

  • 我们通过 train.report(..., checkpoint=checkpoint) 实现 检查点保存 。请注意,每个检查点必须与一组指标一起报告——这样,检查点可以根据指定的指标进行排序。

  • 训练迭代 epoch 期间保存的检查点被保存到训练节点上的路径 <storage_path>/<exp_name>/<trial_name>/checkpoint_<epoch>,并可以根据 存储配置 进一步同步到统一的存储位置。

  • 我们通过 train.get_checkpoint() 实现 检查点加载 。每当 Tune 恢复一个试验时,这将填充试验的最新检查点。这种情况发生在 (1) 试验在遇到故障后被配置为重试,(2) 实验正在恢复,以及 (3) 试验在暂停后继续进行(例如:PBT)。

备注

checkpoint_frequencycheckpoint_at_end 在函数API检查点中不起作用。这些是通过函数可训练对象手动配置的。例如,如果你想每三个周期检查点一次,你可以这样做:

NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3


def train_func(config):
    for epoch in range(1, config["epochs"] + 1):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        if epoch % CHECKPOINT_FREQ == 0:
            with tempfile.TemporaryDirectory() as tempdir:
                # Save a checkpoint in tempdir.
                train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            train.report(metrics)


tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()

查看 这里以获取更多关于创建检查点的信息

类 API 检查点#

你也可以使用 Trainable 类 API 实现检查点/恢复:

import os
import torch
from torch import nn

from ray import train, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()

你可以通过三种不同的机制进行检查点操作:手动、定期和在终止时。

手动检查点#

自定义的可训练对象可以通过在 step 的结果字典中返回 should_checkpoint: True (或 tune.result.SHOULD_CHECKPOINT: True) 来手动触发检查点。这在竞价实例中特别有用:

import random


# to be implemented by user.
def detect_instance_preemption():
    choice = random.randint(1, 100)
    # simulating a 1% chance of preemption.
    return choice <= 1


def train_func(self):
    # training code
    result = {"mean_accuracy": "my_accuracy"}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result


在上面的例子中,如果 detect_instance_preemption 返回 True,可以触发手动检查点。

定期检查点#

可以通过设置 checkpoint_frequency=N 来每 N 次迭代检查点试验,例如:


tuner = tune.Tuner(
    MyTrainableClass,
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(checkpoint_frequency=10),
    ),
)
tuner.fit()

终止时的检查点#

checkpoint_frequency 可能不会与实验的精确结束时间一致。如果你想在试验结束时创建一个检查点,你可以额外设置 checkpoint_at_end=True

tuner = tune.Tuner(
    MyTrainableClass,
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(
            checkpoint_frequency=10, checkpoint_at_end=True
        ),
    ),
)
tuner.fit()

配置#

可以通过 CheckpointConfig 配置检查点。由于检查点频率是在用户定义的训练循环中手动确定的,因此某些配置不适用于函数训练 API。请参见下面的兼容性矩阵。

类 API

函数 API

num_to_keep

checkpoint_score_attribute

checkpoint_score_order

checkpoint_frequency

checkpoint_at_end

摘要#

在本用户指南中,我们介绍了如何在 Tune 中保存和加载试验检查点。一旦启用了检查点功能,请继续阅读以下指南之一,了解如何:

附录:Tune 存储的数据类型#

实验检查点#

实验级别的检查点保存实验状态。这包括搜索器的状态、试验列表及其状态(例如,PENDING、RUNNING、TERMINATED、ERROR),以及与每个试验相关的元数据(例如,超参数配置、一些派生的试验结果(最小值、最大值、最后值)等)。

实验级别的检查点由驱动程序在头节点上定期保存。默认情况下,保存频率会自动调整,以便最多花费5%的时间来保存实验检查点,其余时间用于处理训练结果和调度。这个时间也可以通过 TUNE_GLOBAL_CHECKPOINT_S 环境变量 进行调整。

试验检查点#

试验级别的检查点捕获每个试验的状态。这通常包括模型和优化器的状态。以下是试验检查点的一些用途:

  • 如果由于某种原因(例如,在按需实例上)中断试验,可以从最后的状态恢复。不会丢失训练时间。

  • 一些搜索器或调度器会暂停试验,以便为其他试验腾出资源进行训练。如果试验可以从最新状态继续训练,这样做才有意义。

  • 检查点可以用于其他下游任务,如批量推理。

学习如何保存和加载试验检查点 这里

试验结果#

试验报告的指标会被保存并记录到各自的试验目录中。这些数据以CSV、JSON或Tensorboard(events.out.tfevents.*)格式存储,可以通过Tensorboard进行检查,并用于实验后的分析。