保存和加载检查点#

Ray Train 提供了一种使用 检查点 来快照训练进度的方式。

这对于以下方面很有用:

  1. 存储表现最佳的模型权重: 将您的模型保存到持久存储中,并将其用于下游服务/推理。

  2. 容错性: 在预先可抢占的机器/pod集群上处理长时间运行训练任务中的节点故障。

  3. 分布式检查点: 在进行 模型并行训练 时,Ray Train 检查点提供了一种简单的方式来 并行上传每个工作节点的模型分片,而无需将完整模型收集到单个节点。

  4. 与 Ray Tune 集成: 某些 Ray Tune 调度器 需要检查点保存和加载。

在训练过程中保存检查点#

Ray Train 提供的 Checkpoint 是一个轻量级接口,表示存在于本地或远程存储上的 目录

例如,一个检查点可以指向云存储中的一个目录:s3://my-bucket/my-checkpoint-dir。一个本地可用的检查点指向本地文件系统中的一个位置:/tmp/my-checkpoint-dir

以下是如何在训练循环中保存检查点:

  1. 将模型检查点写入本地目录。

    • 由于 Checkpoint 仅指向一个目录,因此其内容完全由您决定。

    • 这意味着你可以使用任何你想要的序列化格式。

    • 这使得 使用训练框架提供的熟悉的检查点工具变得容易,例如 torch.savepl.Trainer.save_checkpoint、Accelerate 的 accelerator.save_model、Transformers 的 save_pretrainedtf.keras.Model.save 等。

  2. 使用 Checkpoint.from_directory 从目录创建一个 Checkpoint

  3. 使用 ray.train.report(metrics, checkpoint=...) 向 Ray Train 报告检查点。

    • 与检查点一起报告的指标用于 跟踪表现最佳的检查点

    • 如果配置了持久存储,这将 上传检查点到持久存储 。请参阅 持久存储指南

../../_images/checkpoint_lifecycle.png

一个 Checkpoint 的生命周期,从本地磁盘保存到通过 train.report 上传到持久存储。#

如上图所示,保存检查点的最佳实践是首先将检查点转储到本地临时目录。然后,调用 train.report 将检查点上传到其最终的持久存储位置。之后,可以安全地清理本地临时目录以释放磁盘空间(例如,通过退出 tempfile.TemporaryDirectory 上下文)。

小技巧

在标准DDP训练中,每个工作节点都有完整的模型副本,您应该只从一个工作节点保存和报告检查点,以防止冗余上传。

这通常看起来像:

import tempfile

from ray import train


def train_fn(config):
    ...

    metrics = {...}
    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        checkpoint = None

        # Only the global rank 0 worker saves and reports the checkpoint
        if train.get_context().get_world_rank() == 0:
            ...  # Save checkpoint to temp_checkpoint_dir

            checkpoint = Checkpoint.from_directory(tmpdir)

        train.report(metrics, checkpoint=checkpoint)


如果使用如 DeepSpeed Zero-3 和 FSDP 这样的并行训练策略,其中每个工作节点只拥有完整模型的一部分,你应该从每个工作节点保存并报告一个检查点。参见 从多个工作节点保存检查点(分布式检查点) 示例。

以下是使用不同训练框架保存检查点的几个示例:

import os
import tempfile

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam

import ray.train.torch
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    # toy neural network : 1-layer
    # Wrap the model in DDP
    model = ray.train.torch.prepare_model(nn.Linear(4, 1))
    criterion = nn.MSELoss()

    optimizer = Adam(model.parameters(), lr=3e-4)
    for epoch in range(config["num_epochs"]):
        y = model.forward(X)
        loss = criterion(y, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metrics = {"loss": loss.item()}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None

            should_checkpoint = epoch % config.get("checkpoint_freq", 1) == 0
            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint
            if train.get_context().get_world_rank() == 0 and should_checkpoint:
                torch.save(
                    model.module.state_dict(),  # NOTE: Unwrap the model.
                    os.path.join(temp_checkpoint_dir, "model.pt"),
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            train.report(metrics, checkpoint=checkpoint)


trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()

小技巧

在保存模型到检查点之前,您很可能希望解包 DDP 模型。model.module.state_dict() 是状态字典,其中每个键都没有 "module." 前缀。

Ray Train 利用 PyTorch Lightning 的 Callback 接口来报告指标和检查点。我们提供了一个简单的回调实现,用于报告 on_train_epoch_end

具体来说,在每个训练周期结束时,它

import pytorch_lightning as pl

from ray import train
from ray.train.lightning import RayTrainReportCallback
from ray.train.torch import TorchTrainer


class MyLightningModule(pl.LightningModule):
    # ...

    def on_validation_epoch_end(self):
        ...
        mean_acc = calculate_accuracy()
        self.log("mean_accuracy", mean_acc, sync_dist=True)


def train_func():
    ...
    model = MyLightningModule(...)
    datamodule = MyLightningDataModule(...)

    trainer = pl.Trainer(
        # ...
        callbacks=[RayTrainReportCallback()]
    )
    trainer.fit(model, datamodule=datamodule)


ray_trainer = TorchTrainer(
    train_func,
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(
        checkpoint_config=train.CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="mean_accuracy",
            checkpoint_score_order="max",
        ),
    ),
)

你可以随时从 result.checkpointresult.best_checkpoints 获取保存的检查点路径。

对于更高级的使用(例如,以不同频率报告,报告自定义检查点文件),您可以实现自己的自定义回调。以下是一个每3个周期报告一次检查点的简单示例:

import os
from tempfile import TemporaryDirectory

from pytorch_lightning.callbacks import Callback

import ray
import ray.train
from ray.train import Checkpoint


class CustomRayTrainReportCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        should_checkpoint = trainer.current_epoch % 3 == 0

        with TemporaryDirectory() as tmpdir:
            # Fetch metrics
            metrics = trainer.callback_metrics
            metrics = {k: v.item() for k, v in metrics.items()}

            # Add customized metrics
            metrics["epoch"] = trainer.current_epoch
            metrics["custom_metric"] = 123

            checkpoint = None
            global_rank = ray.train.get_context().get_world_rank() == 0
            if global_rank == 0 and should_checkpoint:
                # Save model checkpoint file to tmpdir
                ckpt_path = os.path.join(tmpdir, "ckpt.pt")
                trainer.save_checkpoint(ckpt_path, weights_only=False)

                checkpoint = Checkpoint.from_directory(tmpdir)

            # Report to train session
            ray.train.report(metrics=metrics, checkpoint=checkpoint)


Ray Train 利用 HuggingFace Transformers Trainer 的 Callback 接口来报告指标和检查点。

选项 1:使用 Ray Train 的默认报告回调

我们提供了一个简单的回调实现 RayTrainReportCallback,用于报告检查点的保存。您可以通过 save_strategysave_steps 更改检查点保存的频率。它会收集最新的记录指标,并与最新保存的检查点一起报告。

from transformers import TrainingArguments

from ray import train
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from ray.train.torch import TorchTrainer


def train_func(config):
    ...

    # Configure logging, saving, evaluation strategies as usual.
    args = TrainingArguments(
        ...,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="step",
    )

    trainer = transformers.Trainer(args, ...)

    # Add a report callback to transformers Trainer
    # =============================================
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)

    trainer.train()


ray_trainer = TorchTrainer(
    train_func,
    run_config=train.RunConfig(
        checkpoint_config=train.CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="eval_loss",  # The monitoring metric
            checkpoint_score_order="min",
        )
    ),
)

注意,RayTrainReportCallback 将最新的指标和检查点绑定在一起,因此用户可以正确配置 logging_strategysave_strategyevaluation_strategy,以确保监控指标与检查点保存的步骤同步记录。

例如,评估指标(本例中为 eval_loss)在评估期间会被记录。如果用户希望根据 eval_loss 保留最佳的3个检查点,他们应调整保存和评估的频率。以下是两个有效的配置示例:

args = TrainingArguments(
    ...,
    evaluation_strategy="epoch",
    save_strategy="epoch",
)

args = TrainingArguments(
    ...,
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=50,
    save_steps=100,
)

# And more ...

选项 2:实现您的自定义报告回调

如果你觉得 Ray Train 的默认 RayTrainReportCallback 对你的用例不够充分,你也可以自己实现一个回调!下面是一个收集最新指标并在保存检查点时报告的示例实现。

from ray import train

from transformers.trainer_callback import TrainerCallback


class MyTrainReportCallback(TrainerCallback):
    def __init__(self):
        super().__init__()
        self.metrics = {}

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Log is called on evaluation step and logging step."""
        self.metrics.update(logs)

    def on_save(self, args, state, control, **kwargs):
        """Event called after a checkpoint save."""

        checkpoint = None
        if train.get_context().get_world_rank() == 0:
            # Build a Ray Train Checkpoint from the latest checkpoint
            checkpoint_path = transformers.trainer.get_last_checkpoint(args.output_dir)
            checkpoint = Checkpoint.from_directory(checkpoint_path)

        # Report to Ray Train with up-to-date metrics
        ray.train.report(metrics=self.metrics, checkpoint=checkpoint)

        # Clear the metrics buffer
        self.metrics = {}


你可以通过实现自己的 Transformers Trainer 回调来定制何时(on_saveon_epoch_endon_evaluate)以及报告什么(自定义指标和检查点文件)。

从多个工作节点保存检查点(分布式检查点)#

在模型并行训练策略中,每个工作节点只拥有完整模型的一部分,您可以从每个工作节点并行保存和报告检查点分片。

../../_images/persistent_storage_checkpoint.png

Ray Train 中的分布式检查点。每个工作节点独立地将自身的检查点分片上传到持久存储中。#

分布式检查点是进行模型并行训练(例如,DeepSpeed、FSDP、Megatron-LM)时保存检查点的最佳实践。

有两个主要的好处:

  1. 它更快,从而减少了空闲时间。 更快的检查点激励更频繁的检查点!

    每个工作节点可以并行上传其检查点分片,最大化集群的网络带宽。与单个节点上传大小为 M 的完整模型不同,集群将负载分布在 N 个节点上,每个节点上传大小为 M / N 的分片。

  2. 分布式检查点避免了将整个模型收集到单个工作者的CPU内存中。

    这个收集操作对执行检查点的worker提出了大量的CPU内存需求,并且是OOM错误的常见来源。

以下是使用 PyTorch 进行分布式检查点的示例:

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    ...

    with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
        rank = train.get_context().get_world_rank()
        torch.save(
            ...,
            os.path.join(temp_checkpoint_dir, f"model-rank={rank}.pt"),
        )
        checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

        train.report(metrics, checkpoint=checkpoint)


trainer = TorchTrainer(
    train_func,
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(storage_path="s3://bucket/"),
)
# The checkpoint in cloud storage will contain: model-rank=0.pt, model-rank=1.pt

备注

具有相同名称的检查点文件会在工作节点之间发生冲突。您可以通过为检查点文件添加特定于等级的后缀来解决这个问题。

请注意,文件名冲突不会导致错误,但最终保存的将是最后上传的版本。如果所有工作节点上的文件内容相同,这是可以接受的。

由DeepSpeed等框架提供的模型分片保存工具已经创建了特定等级的文件名,所以你通常不需要担心这个问题。

配置检查点#

Ray Train 通过 CheckpointConfig 提供了一些检查点配置选项。主要的配置是仅保留与某个指标相关的最佳 K 个检查点。性能较低的检查点会被删除以节省存储空间。默认情况下,所有检查点都会被保留。

from ray.train import RunConfig, CheckpointConfig

# Example 1: Only keep the 2 *most recent* checkpoints and delete the others.
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=2))


# Example 2: Only keep the 2 *best* checkpoints and delete the others.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        # *Best* checkpoints are determined by these params:
        checkpoint_score_attribute="mean_accuracy",
        checkpoint_score_order="max",
    ),
    # This will store checkpoints on S3.
    storage_path="s3://remote-bucket/location",
)

备注

如果你想通过 CheckpointConfig 保存与某个指标相关的顶部 num_to_keep 个检查点,请确保该指标始终与检查点一起报告。

训练后使用检查点#

最新的保存检查点可以通过 Result.checkpoint 访问。

可以通过 Result.best_checkpoints 访问所有持久化的检查点列表。如果设置了 CheckpointConfig(num_to_keep),此列表将包含最佳的 num_to_keep 个检查点。

有关检查训练结果的完整指南,请参阅 检查训练结果

Checkpoint.as_directoryCheckpoint.to_directory 是与训练检查点交互的两个主要API:

from pathlib import Path

from ray.train import Checkpoint

# For demonstration, create a locally available directory with a `model.pt` file.
example_checkpoint_dir = Path("/tmp/test-checkpoint")
example_checkpoint_dir.mkdir()
example_checkpoint_dir.joinpath("model.pt").touch()

# Create the checkpoint, which is a reference to the directory.
checkpoint = Checkpoint.from_directory(example_checkpoint_dir)

# Inspect the checkpoint's contents with either `as_directory` or `to_directory`:
with checkpoint.as_directory() as checkpoint_dir:
    assert Path(checkpoint_dir).joinpath("model.pt").exists()

checkpoint_dir = checkpoint.to_directory()
assert Path(checkpoint_dir).joinpath("model.pt").exists()

对于 Lightning 和 Transformers,如果你在训练函数中使用默认的 RayTrainReportCallback 进行检查点保存,你可以按如下方式检索原始检查点文件:

# After training finished
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    lightning_checkpoint_path = f"{checkpoint_dir}/checkpoint.ckpt"
# After training finished
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    hf_checkpoint_path = f"{checkpoint_dir}/checkpoint/"

Restore training state from a checkpoint#

In order to enable fault tolerance, you should modify your training loop to restore training state from a Checkpoint.

要恢复的 检查点 可以通过 ray.train.get_checkpoint 在训练函数中访问。

通过 ray.train.get_checkpoint 返回的检查点有两种填充方式:

  1. 它可以自动填充为最新报告的检查点,例如在 自动故障恢复手动恢复 期间。

  2. 可以通过向 Ray Trainerresume_from_checkpoint 参数传递一个检查点来手动填充。这对于使用先前运行的检查点初始化新的训练运行非常有用。

import os
import tempfile

import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam

import ray.train.torch
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    # toy neural network : 1-layer
    model = nn.Linear(4, 1)
    optimizer = Adam(model.parameters(), lr=3e-4)
    criterion = nn.MSELoss()

    # Wrap the model in DDP and move it to GPU.
    model = ray.train.torch.prepare_model(model)

    # ====== Resume training state from the checkpoint. ======
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            model_state_dict = torch.load(
                os.path.join(checkpoint_dir, "model.pt"),
                # map_location=...,  # Load onto a different device if needed.
            )
            model.module.load_state_dict(model_state_dict)
            optimizer.load_state_dict(
                torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
            )
            start_epoch = (
                torch.load(os.path.join(checkpoint_dir, "extra_state.pt"))["epoch"] + 1
            )
    # ========================================================

    for epoch in range(start_epoch, config["num_epochs"]):
        y = model.forward(X)
        loss = criterion(y, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metrics = {"loss": loss.item()}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None

            should_checkpoint = epoch % config.get("checkpoint_freq", 1) == 0
            # In standard DDP training, where the model is the same across all ranks,
            # only the global rank 0 worker needs to save and report the checkpoint
            if train.get_context().get_world_rank() == 0 and should_checkpoint:
                # === Make sure to save all state needed for resuming training ===
                torch.save(
                    model.module.state_dict(),  # NOTE: Unwrap the model.
                    os.path.join(temp_checkpoint_dir, "model.pt"),
                )
                torch.save(
                    optimizer.state_dict(),
                    os.path.join(temp_checkpoint_dir, "optimizer.pt"),
                )
                torch.save(
                    {"epoch": epoch},
                    os.path.join(temp_checkpoint_dir, "extra_state.pt"),
                )
                # ================================================================
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            train.report(metrics, checkpoint=checkpoint)

        if epoch == 1:
            raise RuntimeError("Intentional error to showcase restoration!")


trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
    run_config=train.RunConfig(failure_config=train.FailureConfig(max_failures=1)),
)
result = trainer.fit()

# Seed a training run with a checkpoint using `resume_from_checkpoint`
trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
    resume_from_checkpoint=result.checkpoint,
)
import os

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer
from ray.train.lightning import RayTrainReportCallback


def train_func():
    model = MyLightningModule(...)
    datamodule = MyLightningDataModule(...)
    trainer = pl.Trainer(..., callbacks=[RayTrainReportCallback()])

    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            ckpt_path = os.path.join(ckpt_dir, "checkpoint.ckpt")
            trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
    else:
        trainer.fit(model, datamodule=datamodule)


# Build a Ray Train Checkpoint
# Suppose we have a Lightning checkpoint at `s3://bucket/ckpt_dir/checkpoint.ckpt`
checkpoint = Checkpoint("s3://bucket/ckpt_dir")

# Resume training from checkpoint file
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=train.ScalingConfig(num_workers=2),
    resume_from_checkpoint=checkpoint,
)

备注

在这些示例中,使用了 Checkpoint.as_directory 将检查点内容视为本地目录。

如果检查点指向一个本地目录,此方法仅返回本地目录路径,而不进行复制。

如果检查点指向一个远程目录,此方法将把检查点下载到本地临时目录,并返回该临时目录的路径。

如果在同一节点上的多个进程同时调用此方法, 只有单个进程会执行下载,而其他进程则等待下载完成。下载完成后,所有进程都会接收到相同的本地(临时)目录以供读取。

一旦所有进程完成与检查点的交互,临时目录将被清理。