ray.train.report#
- ray.train.report(metrics: Dict, *, checkpoint: Checkpoint | None = None) None #
报告指标并可选择保存检查点。
如果提供了检查点,它将被 持久化到存储中。
如果在多个分布式训练工作节点中调用此方法:
只有 rank 0 工作节点报告的指标会被 Ray Train 跟踪。请参阅 指标日志记录指南。
只要有一个或多个工作节点报告了非None的检查点,就会注册一个检查点。请参阅 检查点指南。
多个工作者的检查点将被合并到持久存储中的一个目录中。请参阅 分布式检查点指南。
备注
每次调用此方法都会自动增加底层
training_iteration
的数值。这个“迭代”的物理意义由用户根据他们调用report
的频率来定义。它不一定对应一个 epoch。警告
所有工作线程必须调用
ray.train.report
相同次数,以便 Ray Train 能够在工作线程之间正确同步训练状态。否则,您的训练将会挂起。警告
此方法不会作为分布式训练工作者的屏障。工作者会上传他们的检查点,然后立即继续训练。如果你需要同步工作者,可以使用框架原生的屏障,例如
torch.distributed.barrier()
。示例
import tempfile from ray import train from ray.train import Checkpoint from ray.train.torch import TorchTrainer def train_func(config): start_epoch = 0 checkpoint = train.get_checkpoint() if checkpoint: with checkpoint.as_directory() as checkpoint_dir: # Load back training state ... for epoch in range(start_epoch, config.get("num_epochs", 10)): # Do training... metrics = {"loss": ...} with tempfile.TemporaryDirectory() as temp_checkpoint_dir: # Save the checkpoint... # torch.save(...) checkpoint = Checkpoint.from_directory(temp_checkpoint_dir) # Example: Only the rank 0 worker uploads the checkpoint. if ray.train.get_context().get_world_rank() == 0: train.report(metrics, checkpoint=checkpoint) else: train.report(metrics, checkpoint=None) trainer = TorchTrainer( train_func, scaling_config=train.ScalingConfig(num_workers=2) )
- 参数:
metrics – 你想要报告的指标。
checkpoint – 您想要报告的可选检查点。