ray.train.report#

ray.train.report(metrics: Dict, *, checkpoint: Checkpoint | None = None) None#

报告指标并可选择保存检查点。

如果提供了检查点,它将被 持久化到存储中

如果在多个分布式训练工作节点中调用此方法:

  • 只有 rank 0 工作节点报告的指标会被 Ray Train 跟踪。请参阅 指标日志记录指南

  • 只要有一个或多个工作节点报告了非None的检查点,就会注册一个检查点。请参阅 检查点指南

  • 多个工作者的检查点将被合并到持久存储中的一个目录中。请参阅 分布式检查点指南

备注

每次调用此方法都会自动增加底层 training_iteration 的数值。这个“迭代”的物理意义由用户根据他们调用 report 的频率来定义。它不一定对应一个 epoch。

警告

所有工作线程必须调用 ray.train.report 相同次数,以便 Ray Train 能够在工作线程之间正确同步训练状态。否则,您的训练将会挂起。

警告

此方法不会作为分布式训练工作者的屏障。工作者会上传他们的检查点,然后立即继续训练。如果你需要同步工作者,可以使用框架原生的屏障,例如 torch.distributed.barrier()

示例

import tempfile

from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def train_func(config):
    start_epoch = 0
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            # Load back training state
            ...

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Do training...

        metrics = {"loss": ...}

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
           # Save the checkpoint...
           # torch.save(...)

            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            # Example: Only the rank 0 worker uploads the checkpoint.
            if ray.train.get_context().get_world_rank() == 0:
                train.report(metrics, checkpoint=checkpoint)
            else:
                train.report(metrics, checkpoint=None)

trainer = TorchTrainer(
    train_func, scaling_config=train.ScalingConfig(num_workers=2)
)
参数:
  • metrics – 你想要报告的指标。

  • checkpoint – 您想要报告的可选检查点。