如何保存和加载试验检查点#
试验检查点是 Tune 存储的三种数据类型之一。这些是由用户定义的,旨在快照您的训练进度!
试验级别的检查点通过 Tune Trainable API 保存:这是你定义自定义训练逻辑的地方,也是你定义要检查点的试验状态的地方。在本指南中,我们将展示如何为 Tune 的函数训练和类训练 API 保存和加载检查点,并引导你了解配置选项。
函数 API 检查点#
如果使用 Ray Tune 的函数 API,可以按以下方式保存和加载检查点。要创建检查点,请使用 from_directory()
API。
import os
import tempfile
from ray import train, tune
from ray.train import Checkpoint
def train_func(config):
start = 1
my_model = MyModel()
checkpoint = train.get_checkpoint()
if checkpoint:
with checkpoint.as_directory() as checkpoint_dir:
checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
start = checkpoint_dict["epoch"] + 1
my_model.load_state_dict(checkpoint_dict["model_state"])
for epoch in range(start, config["epochs"] + 1):
# Model training here
# ...
metrics = {"metric": 1}
with tempfile.TemporaryDirectory() as tempdir:
torch.save(
{"epoch": epoch, "model_state": my_model.state_dict()},
os.path.join(tempdir, "checkpoint.pt"),
)
train.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))
tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()
在上面的代码片段中:
我们通过
train.report(..., checkpoint=checkpoint)
实现 检查点保存 。请注意,每个检查点必须与一组指标一起报告——这样,检查点可以根据指定的指标进行排序。训练迭代
epoch
期间保存的检查点被保存到训练节点上的路径<storage_path>/<exp_name>/<trial_name>/checkpoint_<epoch>
,并可以根据 存储配置 进一步同步到统一的存储位置。我们通过
train.get_checkpoint()
实现 检查点加载 。每当 Tune 恢复一个试验时,这将填充试验的最新检查点。这种情况发生在 (1) 试验在遇到故障后被配置为重试,(2) 实验正在恢复,以及 (3) 试验在暂停后继续进行(例如:PBT)。
备注
checkpoint_frequency
和 checkpoint_at_end
在函数API检查点中不起作用。这些是通过函数可训练对象手动配置的。例如,如果你想每三个周期检查点一次,你可以这样做:
NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3
def train_func(config):
for epoch in range(1, config["epochs"] + 1):
# Model training here
# ...
# Report metrics and save a checkpoint
metrics = {"metric": "my_metric"}
if epoch % CHECKPOINT_FREQ == 0:
with tempfile.TemporaryDirectory() as tempdir:
# Save a checkpoint in tempdir.
train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
else:
train.report(metrics)
tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()
类 API 检查点#
你也可以使用 Trainable 类 API 实现检查点/恢复:
import os
import torch
from torch import nn
from ray import train, tune
class MyTrainableClass(tune.Trainable):
def setup(self, config):
self.model = nn.Sequential(
nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
)
def step(self):
return {}
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
torch.save(self.model.state_dict(), checkpoint_path)
return tmp_checkpoint_dir
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
self.model.load_state_dict(torch.load(checkpoint_path))
tuner = tune.Tuner(
MyTrainableClass,
param_space={"input_size": 64},
run_config=train.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=train.CheckpointConfig(checkpoint_frequency=2),
),
)
tuner.fit()
你可以通过三种不同的机制进行检查点操作:手动、定期和在终止时。
手动检查点#
自定义的可训练对象可以通过在 step
的结果字典中返回 should_checkpoint: True
(或 tune.result.SHOULD_CHECKPOINT: True
) 来手动触发检查点。这在竞价实例中特别有用:
import random
# to be implemented by user.
def detect_instance_preemption():
choice = random.randint(1, 100)
# simulating a 1% chance of preemption.
return choice <= 1
def train_func(self):
# training code
result = {"mean_accuracy": "my_accuracy"}
if detect_instance_preemption():
result.update(should_checkpoint=True)
return result
在上面的例子中,如果 detect_instance_preemption
返回 True,可以触发手动检查点。
定期检查点#
可以通过设置 checkpoint_frequency=N
来每 N 次迭代检查点试验,例如:
tuner = tune.Tuner(
MyTrainableClass,
run_config=train.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=train.CheckpointConfig(checkpoint_frequency=10),
),
)
tuner.fit()
终止时的检查点#
checkpoint_frequency 可能不会与实验的精确结束时间一致。如果你想在试验结束时创建一个检查点,你可以额外设置 checkpoint_at_end=True
:
tuner = tune.Tuner(
MyTrainableClass,
run_config=train.RunConfig(
stop={"training_iteration": 2},
checkpoint_config=train.CheckpointConfig(
checkpoint_frequency=10, checkpoint_at_end=True
),
),
)
tuner.fit()
配置#
可以通过 CheckpointConfig
配置检查点。由于检查点频率是在用户定义的训练循环中手动确定的,因此某些配置不适用于函数训练 API。请参见下面的兼容性矩阵。
类 API |
函数 API |
|
---|---|---|
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
✅ |
|
✅ |
❌ |
|
✅ |
❌ |
摘要#
在本用户指南中,我们介绍了如何在 Tune 中保存和加载试验检查点。一旦启用了检查点功能,请继续阅读以下指南之一,了解如何:
附录:Tune 存储的数据类型#
实验检查点#
实验级别的检查点保存实验状态。这包括搜索器的状态、试验列表及其状态(例如,PENDING、RUNNING、TERMINATED、ERROR),以及与每个试验相关的元数据(例如,超参数配置、一些派生的试验结果(最小值、最大值、最后值)等)。
实验级别的检查点由驱动程序在头节点上定期保存。默认情况下,保存频率会自动调整,以便最多花费5%的时间来保存实验检查点,其余时间用于处理训练结果和调度。这个时间也可以通过 TUNE_GLOBAL_CHECKPOINT_S 环境变量 进行调整。
试验检查点#
试验级别的检查点捕获每个试验的状态。这通常包括模型和优化器的状态。以下是试验检查点的一些用途:
如果由于某种原因(例如,在按需实例上)中断试验,可以从最后的状态恢复。不会丢失训练时间。
一些搜索器或调度器会暂停试验,以便为其他试验腾出资源进行训练。如果试验可以从最新状态继续训练,这样做才有意义。
检查点可以用于其他下游任务,如批量推理。
学习如何保存和加载试验检查点 这里。
试验结果#
试验报告的指标会被保存并记录到各自的试验目录中。这些数据以CSV、JSON或Tensorboard(events.out.tfevents.*)格式存储,可以通过Tensorboard进行检查,并用于实验后的分析。